示例#1
0
			def metric_fn(per_example_loss,
						logits,
						label_ids):
				"""Build the sentence-level evaluation metrics.

				Predictions are the argmax of ``logits`` over the last axis.
				``num_labels`` and ``label_lst`` are not parameters; they are
				read from the enclosing scope.

				Args:
					per_example_loss: loss values averaged by ``tf.metrics.mean``.
					logits: scores whose last axis is reduced with argmax.
					label_ids: gold labels compared against the predictions.

				Returns:
					dict with keys "f1", "loss" and "acc", each holding the
					result of the corresponding metric function
					(``tf.metrics.*`` return a ``(value, update_op)`` pair).
				"""
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)

				# NOTE(review): label_ids is used exactly as received; the
				# metrics are assumed to accept its shape -- confirm upstream.
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)

				# label_lst is the 4th positional argument of tf_metrics.f1
				# (presumably ``pos_indices`` -- verify against tf_metrics).
				sentence_f = tf_metrics.f1(label_ids,
										sentence_predictions,
										num_labels,
										label_lst, average="macro")

				return {
					"f1": sentence_f,
					"loss": sentence_mean_loss,
					"acc": sentence_accuracy,
				}
示例#2
0
            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights, per_example_loss,
                          logits, label_ids):
                """Build masked-LM and sentence-level evaluation outputs.

                Masked-LM inputs are flattened before the weighted accuracy
                and weighted mean loss are built. Sentence predictions are
                the argmax of ``logits`` over the last axis. ``num_labels``
                and ``label_lst`` are read from the enclosing scope.

                Args:
                    masked_lm_example_loss: per-position masked-LM loss.
                    masked_lm_log_probs: masked-LM log-probabilities; reshaped
                        to two dimensions keeping the last axis.
                    masked_lm_ids: gold ids for the masked positions.
                    masked_lm_weights: weights applied to both masked-LM
                        metrics.
                    per_example_loss: sentence-level loss values.
                    logits: sentence-level scores.
                    label_ids: sentence-level gold labels.

                Returns:
                    dict in which every metric entry is the LAST element of
                    the tuple returned by its metric function (for
                    ``tf.metrics.*`` that is the update op), plus
                    "probabilities" (softmax of ``logits``) and "label_ids"
                    (passed through unchanged).
                """
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                # label_lst is the 4th positional argument of tf_metrics.f1
                # (presumably ``pos_indices`` -- verify against tf_metrics).
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                # [-1] keeps only the last tuple element of each metric.
                return {
                    "masked_lm_accuracy": masked_lm_accuracy[-1],
                    "masked_lm_loss": masked_lm_mean_loss[-1],
                    "sentence_f": sentence_f[-1],
                    "sentence_loss": sentence_mean_loss[-1],
                    "probabilities":
                    tf.exp(tf.nn.log_softmax(logits, name="softmax_tensor")),
                    "label_ids": label_ids,
                }
示例#3
0
def eval_logtis(logits, features, num_labels):
    """Build weighted accuracy and macro-F1 metrics for ``logits``.

    Args:
        logits: scores; predictions are the argmax over the last axis.
        features: mapping providing 'label_ids' and 'label_weights'; both
            are cast to int32 and flattened before use.
        num_labels: number of classes handed to ``tf_metrics.f1``.

    Returns:
        dict with keys "f1" and "acc" holding the results of
        ``tf_metrics.f1`` and ``tf.metrics.accuracy`` respectively.
    """

    def _flat_int32(feature_name):
        # Cast first, then flatten to one dimension.
        as_int = tf.cast(features[feature_name], tf.int32)
        return tf.reshape(as_int, [-1])

    gold = _flat_int32('label_ids')
    example_weights = _flat_int32('label_weights')

    predicted = tf.argmax(logits, axis=-1, output_type=tf.int32)

    accuracy_metric = tf.metrics.accuracy(
        labels=gold, predictions=predicted, weights=example_weights)

    macro_f1_metric = tf_metrics.f1(
        gold, predicted, num_labels,
        weights=example_weights, average="macro")

    return {"f1": macro_f1_metric, "acc": accuracy_metric}
示例#4
0
            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights, per_example_loss,
                          logits, label_ids):
                """Build masked-LM and sentence-level evaluation metrics.

                Masked-LM inputs are flattened before the weighted accuracy
                and weighted mean loss are built. Sentence predictions are
                the argmax of ``logits`` over the last axis. ``num_labels``
                and ``label_lst`` are read from the enclosing scope.

                Args:
                    masked_lm_example_loss: per-position masked-LM loss.
                    masked_lm_log_probs: masked-LM log-probabilities; reshaped
                        to two dimensions keeping the last axis.
                    masked_lm_ids: gold ids for the masked positions.
                    masked_lm_weights: weights applied to both masked-LM
                        metrics.
                    per_example_loss: sentence-level loss values.
                    logits: sentence-level scores.
                    label_ids: sentence-level gold labels.

                Returns:
                    dict with keys "masked_lm_accuracy", "masked_lm_loss",
                    "sentence_f" and "sentence_loss", each holding the full
                    result of its metric function (``tf.metrics.*`` return a
                    ``(value, update_op)`` pair).
                """
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                # label_lst is the 4th positional argument of tf_metrics.f1
                # (presumably ``pos_indices`` -- verify against tf_metrics).
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "sentence_f": sentence_f,
                    "sentence_loss": sentence_mean_loss
                }
def discriminator_metric_eval(input_dict, **kargs):
    """Build discriminator evaluation metrics from real and fake logits.

    Predictions are the argmax over the last axis of each logits tensor.
    Rows coming from ``input_dict['true_logits']`` are labelled 1 and rows
    coming from ``input_dict['fake_logits']`` are labelled 0; both halves
    are concatenated along axis 0 before the metrics are built.

    Args:
        input_dict: mapping providing 'true_logits' and 'fake_logits'.
            'true_logits' must be rank 2 (checked through
            ``bert_utils.get_shape_list``).
        **kargs: optional settings; only 'use_tpu' (default ``True``) is
            read. When it is truthy the overall precision and recall come
            from ``tf.compat.v1.metrics`` on one-hot tensors, otherwise
            from ``tf_metrics``.

    Returns:
        dict mapping metric names to the results of the metric functions
        (``tf.compat.v1.metrics.*`` return a ``(value, update_op)`` pair).
        The key 'discriminator_precison' keeps its historical spelling.
    """
    d_out_real = input_dict['true_logits']
    d_out_fake = input_dict['fake_logits']

    input_shape_list = bert_utils.get_shape_list(d_out_real, expected_rank=[2])
    batch_size = input_shape_list[0]

    # NOTE(review): both label vectors are sized from the real batch, which
    # assumes the fake logits share that batch size -- confirm with callers.
    true_labels = tf.cast(tf.ones(batch_size), tf.int32)
    fake_labels = tf.cast(tf.zeros(batch_size), tf.int32)

    pred_true_label = tf.argmax(d_out_real, axis=-1)
    pred_fake_label = tf.argmax(d_out_fake, axis=-1)

    all_pred_label = tf.concat([pred_true_label, pred_fake_label], axis=0)
    all_true_label = tf.concat([true_labels, fake_labels], axis=0)

    def _macro(metric_builder, **extra):
        # Every tf_metrics call shares labels, predictions, 2 classes, macro.
        return metric_builder(all_true_label, all_pred_label, 2,
                              average="macro", **extra)

    def _per_class_entries():
        # pos_indices=[0] feeds the "*_original" keys, [1] the "*_replaced".
        specs = (
            ('discriminator_f1_original', tf_metrics.f1, 0),
            ('discriminator_f1_replaced', tf_metrics.f1, 1),
            ('discriminator_precision_original', tf_metrics.precision, 0),
            ('discriminator_precision_replaced', tf_metrics.precision, 1),
            ('discriminator_recall_original', tf_metrics.recall, 0),
            ('discriminator_recall_replaced', tf_metrics.recall, 1),
        )
        return [(key, _macro(builder, pos_indices=[class_index]))
                for key, builder, class_index in specs]

    # Previously never initialised, which made every call raise NameError.
    output_dict = {}

    if not kargs.get('use_tpu', True):
        discriminator_f1 = _macro(tf_metrics.f1)
        discriminator_precison = _macro(tf_metrics.precision)
        discriminator_recall = _macro(tf_metrics.recall)
        per_class = _per_class_entries()

        output_dict['discriminator_f1'] = discriminator_f1
        output_dict['discriminator_precison'] = discriminator_precison
        output_dict['discriminator_recall'] = discriminator_recall
        output_dict.update(per_class)
    else:
        discriminator_recall = tf.compat.v1.metrics.recall(
            tf.one_hot(all_true_label, 2), tf.one_hot(all_pred_label, 2))
        discriminator_precison = tf.compat.v1.metrics.precision(
            tf.one_hot(all_true_label, 2), tf.one_hot(all_pred_label, 2))
        discriminator_f1 = _macro(tf_metrics.f1)
        per_class = _per_class_entries()

        output_dict.update(per_class)
        output_dict['discriminator_f1'] = discriminator_f1
        output_dict['discriminator_precison'] = discriminator_precison
        output_dict['discriminator_recall'] = discriminator_recall
    return output_dict
示例#6
0
def discriminator_metric_eval(per_example_loss, logits, input_ids, sampled_ids,
                              input_mask, **kargs):
    """Assemble token-level evaluation metrics for the discriminator.

    A position is labelled 1 when ``sampled_ids`` differs from ``input_ids``
    and 0 otherwise. Positions whose ``input_ids`` value is 100, 101 or 102
    are removed from ``input_mask`` so they carry no weight in any metric.
    Predictions are the argmax of ``logits`` over the last axis. Labels,
    predictions, mask and loss are all flattened to one dimension.

    Args:
        per_example_loss: loss values, averaged under the flattened mask.
        logits: discriminator scores.
        input_ids: original token ids.
        sampled_ids: token ids after sampling/replacement.
        input_mask: mask of positions to score; cast to int32.
        **kargs: optional settings; only 'use_tpu' (default ``True``) is
            read. When it is truthy the overall precision and recall come
            from ``tf.compat.v1.metrics`` on one-hot tensors, otherwise
            from ``tf_metrics``.

    Returns:
        dict mapping metric names to the results of the metric functions
        (``tf.metrics.*`` return a ``(value, update_op)`` pair). The key
        'discriminator_precison' keeps its historical spelling.
    """
    # Label convention -> original: 0, replaced: 1.
    replaced_flags = tf.cast(
        tf.not_equal(tf.cast(input_ids, tf.int32),
                     tf.cast(sampled_ids, tf.int32)),
        tf.int32)

    # Ids 100 (unk), 101 (cls) and 102 (sep) count as neither class.
    unk_hits, cls_hits, sep_hits = [
        tf.cast(tf.math.equal(input_ids, special_id), tf.float32)
        for special_id in (100, 101, 102)
    ]
    excluded = unk_hits + cls_hits + sep_hits

    scoring_mask = (tf.cast(input_mask, tf.int32)
                    * tf.cast(1 - excluded, tf.int32))

    predicted = tf.argmax(logits, axis=-1, output_type=tf.int32)

    flat_labels = tf.reshape(replaced_flags, [-1])
    flat_predicted = tf.reshape(predicted, [-1])
    flat_mask = tf.reshape(scoring_mask, [-1])

    output_dict = {
        "discriminator_accuracy": tf.metrics.accuracy(
            labels=flat_labels,
            predictions=flat_predicted,
            weights=flat_mask),
        "discriminator_loss": tf.metrics.mean(
            values=tf.reshape(per_example_loss, [-1]),
            weights=flat_mask),
    }

    def macro_metric(metric_builder, **extra):
        # Shared call shape for every tf_metrics metric in this function.
        return metric_builder(flat_labels, flat_predicted, 2,
                              weights=flat_mask, average="macro", **extra)

    # (output key, metric function, class index handed to pos_indices)
    per_class_specs = (
        ('discriminator_f1_original', tf_metrics.f1, 0),
        ('discriminator_f1_replaced', tf_metrics.f1, 1),
        ('discriminator_precision_original', tf_metrics.precision, 0),
        ('discriminator_precision_replaced', tf_metrics.precision, 1),
        ('discriminator_recall_original', tf_metrics.recall, 0),
        ('discriminator_recall_replaced', tf_metrics.recall, 1),
    )

    if not kargs.get('use_tpu', True):
        overall = [
            ('discriminator_f1', macro_metric(tf_metrics.f1)),
            ('discriminator_precison', macro_metric(tf_metrics.precision)),
            ('discriminator_recall', macro_metric(tf_metrics.recall)),
        ]
        per_class = [
            (key, macro_metric(builder, pos_indices=[class_index]))
            for key, builder, class_index in per_class_specs
        ]
        output_dict.update(overall + per_class)
    else:
        # Overall recall/precision are computed on one-hot encoded tensors.
        streaming_recall = tf.compat.v1.metrics.recall(
            tf.one_hot(flat_labels, 2),
            tf.one_hot(flat_predicted, 2),
            weights=flat_mask)
        streaming_precision = tf.compat.v1.metrics.precision(
            tf.one_hot(flat_labels, 2),
            tf.one_hot(flat_predicted, 2),
            weights=flat_mask)
        overall_f1 = macro_metric(tf_metrics.f1)
        per_class = [
            (key, macro_metric(builder, pos_indices=[class_index]))
            for key, builder, class_index in per_class_specs
        ]

        output_dict.update(per_class)
        output_dict['discriminator_f1'] = overall_f1
        output_dict['discriminator_precison'] = streaming_precision
        output_dict['discriminator_recall'] = streaming_recall
    return output_dict