def get_prediction_module(self, bert_model, features, is_training,
                          percent_done):
  n_classes = len(self._get_label_mapping())
  reprs = bert_model.get_sequence_output()
  reprs = pretrain_helpers.gather_positions(
      reprs, features[self.name + "_labeled_positions"])
  seq_lengths = tf.cast(
      tf.reduce_sum(features[self.name + "_labels_mask"], axis=1), tf.int32)
  logits = tf.layers.dense(reprs, n_classes)
  with tf.variable_scope("crf", reuse=tf.AUTO_REUSE):
    trans_val = tf.get_variable(
        "transition", shape=[n_classes, n_classes], dtype=tf.float32)
  predict_ids, _ = crf.crf_decode(logits, trans_val, seq_lengths)
  actual_ids = features[self.name + "_labels"]
  log_likelihood, _ = crf.crf_log_likelihood(
      inputs=logits,
      tag_indices=actual_ids,
      sequence_lengths=seq_lengths,
      transition_params=trans_val)
  losses = -log_likelihood
  return losses, dict(
      loss=losses,
      logits=logits,
      predictions=predict_ids,
      labels=features[self.name + "_labels"],
      labels_mask=features[self.name + "_labels_mask"],
      labeled_positions=features[self.name + "_labeled_positions"],
      eid=features[self.name + "_eid"],
  )
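# Hedged usage sketch (not part of the original code): exercises the CRF head
# above on toy shapes, assuming TF 1.x with tf.contrib.crf imported as `crf`.
# All names and shapes below are illustrative.
def _crf_head_toy_example():
  batch_logits = tf.random.normal([2, 4, 3])      # [batch, max_len, n_classes]
  tag_ids = tf.constant([[0, 1, 2, 1], [2, 0, 0, 1]], tf.int32)
  seq_lengths = tf.constant([4, 3], tf.int32)     # second sequence is padded
  with tf.variable_scope("crf_toy", reuse=tf.AUTO_REUSE):
    transitions = tf.get_variable("transition", [3, 3], tf.float32)
  # Viterbi decode of the most likely tag sequence per example
  decoded_ids, _ = crf.crf_decode(batch_logits, transitions, seq_lengths)
  # Per-example negative log-likelihood, as in the prediction module above
  log_likelihood, _ = crf.crf_log_likelihood(
      inputs=batch_logits, tag_indices=tag_ids,
      sequence_lengths=seq_lengths, transition_params=transitions)
  per_example_loss = -log_likelihood              # shape [batch]
  return decoded_ids, per_example_loss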
def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
  """Masked language modeling softmax layer."""
  masked_lm_weights = inputs.masked_lm_weights
  with tf.variable_scope("generator_predictions"):
    if self._config.uniform_generator:
      logits = tf.zeros(self._bert_config.vocab_size)
      logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size])
      logits = logits_tiled
      logits_embed = None  # no embedding-space logits for the uniform generator
    else:
      relevant_hidden = pretrain_helpers.gather_positions(
          model.get_sequence_output(), inputs.masked_lm_positions)
      hidden = tf.layers.dense(
          relevant_hidden,
          units=modeling.get_shape_list(model.get_embedding_table())[-1],
          activation=modeling.get_activation(self._bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              self._bert_config.initializer_range))
      hidden = modeling.layer_norm(hidden)
      output_bias = tf.get_variable(
          "output_bias",
          shape=[self._bert_config.vocab_size],
          initializer=tf.zeros_initializer())
      logits_embed = tf.matmul(
          hidden, model.get_embedding_table(), transpose_b=True)
      logits = tf.nn.bias_add(logits_embed, output_bias)

    oh_labels = tf.one_hot(
        inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
        dtype=tf.float32)

    probs = tf.nn.softmax(logits)
    log_probs = tf.nn.log_softmax(logits)
    label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)

    numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs)
    denominator = tf.reduce_sum(masked_lm_weights) + 1e-6
    loss = numerator / denominator

    preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

    MLMOutput = collections.namedtuple(
        "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"])
    return MLMOutput(
        logits=logits, probs=probs, per_example_loss=label_log_probs,
        loss=loss, preds=preds), logits_embed
def get_prediction_module(self, bert_model, features, is_training,
                          percent_done):
  n_classes = len(self._get_label_mapping())
  reprs = bert_model.get_sequence_output()
  reprs = pretrain_helpers.gather_positions(
      reprs, features[self.name + "_labeled_positions"])
  logits = tf.layers.dense(reprs, n_classes)
  losses = tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.one_hot(features[self.name + "_labels"], n_classes),
      logits=logits)
  losses *= features[self.name + "_labels_mask"]
  losses = tf.reduce_sum(losses, axis=-1)
  return losses, dict(
      loss=losses,
      logits=logits,
      predictions=tf.argmax(logits, axis=-1),
      labels=features[self.name + "_labels"],
      labels_mask=features[self.name + "_labels_mask"],
      eid=features[self.name + "_eid"],
  )
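# Hedged sketch (illustration, not original code): the masked per-token
# cross-entropy used above, shown on toy shapes (batch 2, 3 labeled positions,
# 4 classes); padded positions are zeroed by the mask before summing.
def _masked_token_xent_toy_example():
  toy_logits = tf.random.normal([2, 3, 4])
  toy_labels = tf.constant([[1, 0, 2], [3, 3, 0]], tf.int32)
  toy_mask = tf.constant([[1., 1., 0.], [1., 0., 0.]])    # 0 = padding
  per_token = tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.one_hot(toy_labels, 4), logits=toy_logits)
  per_example = tf.reduce_sum(per_token * toy_mask, axis=-1)  # shape [2]
  return per_example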
def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
  """Masked language modeling softmax layer."""
  with tf.variable_scope("generator_predictions"):
    if self._config.uniform_generator:
      logits = tf.zeros(self._bert_config.vocab_size)
      logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size])
      logits = logits_tiled
    else:
      relevant_reprs = pretrain_helpers.gather_positions(
          model.get_sequence_output(), inputs.masked_lm_positions)
      logits = get_token_logits(
          relevant_reprs, model.get_embedding_table(), self._bert_config)
    return get_softmax_output(
        logits, inputs.masked_lm_ids, inputs.masked_lm_weights,
        self._bert_config.vocab_size)
def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
  """Masked language modeling softmax layer."""
  masked_lm_weights = inputs.masked_lm_weights
  with tf.variable_scope("generator_predictions"):
    if (self._config.uniform_generator or self._config.identity_generator
        or self._config.heuristic_generator):
      logits = tf.zeros(self._bert_config.vocab_size)
      logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size])
      logits = logits_tiled
    else:
      relevant_hidden = pretrain_helpers.gather_positions(
          model.get_sequence_output(), inputs.masked_lm_positions)
      hidden = tf.layers.dense(
          relevant_hidden,
          units=modeling.get_shape_list(model.get_embedding_table())[-1],
          activation=modeling.get_activation(self._bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              self._bert_config.initializer_range))
      hidden = modeling.layer_norm(hidden)
      output_bias = tf.get_variable(
          "output_bias",
          shape=[self._bert_config.vocab_size],
          initializer=tf.zeros_initializer())
      logits = tf.matmul(hidden, model.get_embedding_table(), transpose_b=True)
      logits = tf.nn.bias_add(logits, output_bias)

    oh_labels = tf.one_hot(
        inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
        dtype=tf.float32)

    probs = tf.nn.softmax(logits)

    if self._config.identity_generator:
      # Build a distribution sharply peaked on the original (masked-out) token
      # and mix it with the generator softmax using a linearly scheduled weight.
      identity_logits = tf.zeros(self._bert_config.vocab_size)
      identity_logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      masked_identity_weights = tf.one_hot(
          inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
          dtype=tf.float32)
      identity_logits_tiled += 25.0 * masked_identity_weights
      identity_logits_tiled += tf.reshape(
          identity_logits, [1, 1, self._bert_config.vocab_size])
      identity_logits = identity_logits_tiled
      identity_probs = tf.nn.softmax(identity_logits)

      identity_weight = (
          self.global_step / tf.cast(self._config.num_train_steps, tf.float32)
      ) * self._config.max_identity_weight
      probs = probs * (1 - identity_weight) + identity_probs * identity_weight
      logits = tf.math.log(probs)  # softmax(log(probs)) = probs
    elif self._config.heuristic_generator:
      # Same idea, but the mixed-in distribution puts mass on synonyms of the
      # masked-out token rather than on the token itself.
      synonym_logits = tf.zeros(self._bert_config.vocab_size)
      synonym_logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      masked_synonym_weights = tf.reduce_sum(
          tf.one_hot(inputs.masked_synonym_ids,
                     depth=self._bert_config.vocab_size, dtype=tf.float32), -2)
      padded_synonym_mask = tf.concat(
          [tf.zeros([1]), tf.ones([self._bert_config.vocab_size - 1])], 0)
      masked_synonym_weights *= tf.expand_dims(
          tf.expand_dims(padded_synonym_mask, 0), 0)
      synonym_logits_tiled += 25.0 * masked_synonym_weights
      synonym_logits_tiled += tf.reshape(
          synonym_logits, [1, 1, self._bert_config.vocab_size])
      synonym_logits = synonym_logits_tiled
      synonym_probs = tf.nn.softmax(synonym_logits)

      if self._config.synonym_scheduler_type == 'linear':
        synonym_weight = (
            self.global_step /
            tf.cast(self._config.num_train_steps, tf.float32)
        ) * self._config.max_synonym_weight
        probs = probs * (1 - synonym_weight) + synonym_probs * synonym_weight
        logits = tf.math.log(probs)  # softmax(log(probs)) = probs

    log_probs = tf.nn.log_softmax(logits)
    label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)

    numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs)
    denominator = tf.reduce_sum(masked_lm_weights) + 1e-6
    loss = numerator / denominator

    preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

    MLMOutput = collections.namedtuple(
        "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"])
    return MLMOutput(
        logits=logits, probs=probs, per_example_loss=label_log_probs,
        loss=loss, preds=preds)
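# Hedged sketch (not from the original code): the probability-mixing scheme
# used by the identity/heuristic generators above, isolated for clarity. The
# generator softmax is blended with a hand-built prior using a linearly
# scheduled weight, then converted back to logits via log(probs), relying on
# softmax(log(p)) = p. The epsilon inside the log is an added assumption to
# guard against log(0); the original code logs the mixed probabilities directly.
def _mix_generator_probs(model_probs, prior_probs, global_step,
                         num_train_steps, max_weight):
  weight = (tf.cast(global_step, tf.float32) /
            tf.cast(num_train_steps, tf.float32)) * max_weight
  mixed_probs = model_probs * (1.0 - weight) + prior_probs * weight
  return mixed_probs, tf.math.log(mixed_probs + 1e-9)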
def __init__(self, config: configure_pretraining.PretrainingConfig,
             features, is_training):
  # Set up model config
  self._config = config
  self._bert_config = training_utils.get_bert_config(config)
  if config.debug:
    self._bert_config.num_hidden_layers = 3
    self._bert_config.hidden_size = 144
    self._bert_config.intermediate_size = 144 * 4
    self._bert_config.num_attention_heads = 4

  # Mask the input
  unmasked_inputs = pretrain_data.features_to_inputs(features)
  masked_inputs = pretrain_helpers.mask(
      config, unmasked_inputs, config.mask_prob)

  # Generator
  embedding_size = (
      self._bert_config.hidden_size if config.embedding_size is None
      else config.embedding_size)
  cloze_output = None
  if config.uniform_generator:
    # simple generator sampling fakes uniformly at random
    mlm_output = self._get_masked_lm_output(masked_inputs, None)
  elif ((config.electra_objective or config.electric_objective)
        and config.untied_generator):
    generator_config = get_generator_config(config, self._bert_config)
    if config.two_tower_generator:
      # two-tower cloze model generator used for electric
      generator = TwoTowerClozeTransformer(
          config, generator_config, unmasked_inputs, is_training,
          embedding_size)
      cloze_output = self._get_cloze_outputs(unmasked_inputs, generator)
      mlm_output = get_softmax_output(
          pretrain_helpers.gather_positions(
              cloze_output.logits, masked_inputs.masked_lm_positions),
          masked_inputs.masked_lm_ids, masked_inputs.masked_lm_weights,
          self._bert_config.vocab_size)
    else:
      # small masked language model generator
      generator = build_transformer(
          config, masked_inputs, is_training, generator_config,
          embedding_size=(None if config.untied_generator_embeddings
                          else embedding_size),
          untied_embeddings=config.untied_generator_embeddings,
          scope="generator")
      mlm_output = self._get_masked_lm_output(masked_inputs, generator)
  else:
    # full-sized masked language model generator if using BERT objective or if
    # the generator and discriminator have tied weights
    generator = build_transformer(
        config, masked_inputs, is_training, self._bert_config,
        embedding_size=embedding_size)
    mlm_output = self._get_masked_lm_output(masked_inputs, generator)
  fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
  self.mlm_output = mlm_output
  self.total_loss = config.gen_weight * (
      cloze_output.loss if config.two_tower_generator else mlm_output.loss)

  # Discriminator
  disc_output = None
  if config.electra_objective or config.electric_objective:
    discriminator = build_transformer(
        config, fake_data.inputs, is_training, self._bert_config,
        reuse=not config.untied_generator, embedding_size=embedding_size)
    disc_output = self._get_discriminator_output(
        fake_data.inputs, discriminator, fake_data.is_fake_tokens,
        cloze_output)
    self.total_loss += config.disc_weight * disc_output.loss

  # Evaluation
  eval_fn_inputs = {
      "input_ids": masked_inputs.input_ids,
      "masked_lm_preds": mlm_output.preds,
      "mlm_loss": mlm_output.per_example_loss,
      "masked_lm_ids": masked_inputs.masked_lm_ids,
      "masked_lm_weights": masked_inputs.masked_lm_weights,
      "input_mask": masked_inputs.input_mask
  }
  if config.electra_objective or config.electric_objective:
    eval_fn_inputs.update({
        "disc_loss": disc_output.per_example_loss,
        "disc_labels": disc_output.labels,
        "disc_probs": disc_output.probs,
        "disc_preds": disc_output.preds,
        "sampled_tokids": tf.argmax(fake_data.sampled_tokens, -1,
                                    output_type=tf.int32)
    })
  eval_fn_keys = eval_fn_inputs.keys()
  eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

  def metric_fn(*args):
    """Computes the loss and accuracy of the model."""
    d = {k: arg for k, arg in zip(eval_fn_keys, args)}
    metrics = dict()
    metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
        labels=tf.reshape(d["masked_lm_ids"], [-1]),
        predictions=tf.reshape(d["masked_lm_preds"], [-1]),
        weights=tf.reshape(d["masked_lm_weights"], [-1]))
    metrics["masked_lm_loss"] = tf.metrics.mean(
        values=tf.reshape(d["mlm_loss"], [-1]),
        weights=tf.reshape(d["masked_lm_weights"], [-1]))
    if config.electra_objective or config.electric_objective:
      metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
          labels=tf.reshape(d["masked_lm_ids"], [-1]),
          predictions=tf.reshape(d["sampled_tokids"], [-1]),
          weights=tf.reshape(d["masked_lm_weights"], [-1]))
      if config.disc_weight > 0:
        metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
        metrics["disc_auc"] = tf.metrics.auc(
            d["disc_labels"] * d["input_mask"],
            d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
        metrics["disc_accuracy"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["input_mask"])
        # precision/recall computed as accuracy restricted (via weights) to
        # predicted-fake and actually-fake positions respectively
        metrics["disc_precision"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["disc_preds"] * d["input_mask"])
        metrics["disc_recall"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["disc_labels"] * d["input_mask"])
    return metrics

  self.eval_metrics = (metric_fn, eval_fn_values)