Example #1
def classifier(config, seq_output,
						input_ids,
						sampled_ids,
						input_mask,
						num_labels,
						dropout_prob,
						**kargs):

	"""
	input_ids: original input ids
	sampled_ids: generated fake ids
	"""

	output_layer = seq_output
	hidden_size = output_layer.shape[-1].value

	output_weights = tf.get_variable(
			"output_weights", [num_labels, hidden_size],
			initializer=tf.truncated_normal_initializer(stddev=0.02))

	output_bias = tf.get_variable(
			"output_bias", [num_labels], initializer=tf.zeros_initializer())

	if config.get('ln_type', 'postln') == 'preln':
		output_layer = albert_modules.layer_norm(output_layer)
	elif config.get('ln_type', 'postln') == 'postln':
		output_layer = output_layer
	else:
		output_layer = output_layer

	output_layer = tf.nn.dropout(output_layer, keep_prob=1 - dropout_prob)

	logits = tf.einsum("abc,dc->abd", seq_output, output_weights)
	logits = tf.nn.bias_add(logits, output_bias) # batch x seq_length x 2

	input_ids = tf.cast(input_ids, tf.int32)
	sampled_ids = tf.cast(sampled_ids, tf.int32)

	discriminator_label_ids = tf.cast(tf.equal(input_ids, sampled_ids), tf.int32)

	per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
												logits=logits,
												labels=tf.stop_gradient(discriminator_label_ids))
	loss = per_example_loss * tf.cast(input_mask, tf.float32)

	loss = tf.reduce_sum(loss) / (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32)))

	return (loss, logits, per_example_loss)
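The discriminator loss above is a per-token cross-entropy averaged only over real tokens: multiplying by input_mask zeroes out padding, and the 1e-10 term guards against division by zero when the mask is empty. A minimal NumPy sketch of that masked averaging, with made-up toy values:

import numpy as np

# Toy per-token losses for a batch of 2 sequences of length 4 (hypothetical values).
per_example_loss = np.array([[0.2, 0.9, 0.4, 0.0],
                             [1.3, 0.1, 0.0, 0.0]])
input_mask = np.array([[1, 1, 1, 0],
                       [1, 1, 0, 0]], dtype=np.float32)  # 1 = real token, 0 = padding

masked = per_example_loss * input_mask
loss = masked.sum() / (1e-10 + input_mask.sum())  # average over real tokens only
print(loss)  # (0.2 + 0.9 + 0.4 + 1.3 + 0.1) / 5 = 0.58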
Example #2
def get_next_sentence_output(config, input_tensor, labels, reuse=None, **kargs):
	"""Get loss and log probs for the next sentence prediction."""
	# Simple binary classification. Note that 0 is "next sentence" and 1 is
	# "random sentence". This weight matrix is not used after pre-training.

	scope = kargs.get('scope', None)
	if scope:
		scope = scope + '/' + 'cls/seq_relationship'
	else:
		scope = 'cls/seq_relationship'
	tf.logging.info("**** nsp scope **** %s", str(scope))

	# with tf.variable_scope("cls/seq_relationship", reuse=reuse):
	with tf.variable_scope(scope, reuse=reuse):

		if config.get('ln_type', 'postln') == 'preln':
			input_tensor = albert_modules.layer_norm(input_tensor)
		elif config.get('ln_type', 'postln') == 'postln':
			input_tensor = input_tensor
		else:
			input_tensor = input_tensor

		output_weights = tf.get_variable(
				"output_weights",
				shape=[2, config.hidden_size],
				initializer=albert_modules.create_initializer(config.initializer_range))
		output_bias = tf.get_variable(
				"output_bias", shape=[2], initializer=tf.zeros_initializer())

		logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
		logits = tf.nn.bias_add(logits, output_bias)
		log_probs = tf.nn.log_softmax(logits, axis=-1)
		labels = tf.reshape(labels, [-1])
		one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
		per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
		loss = tf.reduce_mean(per_example_loss)
		return (loss, per_example_loss, log_probs)
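The NSP head computes the loss by hand as -sum(one_hot_labels * log_probs), which for integer labels is equivalent to sparse softmax cross-entropy. A small NumPy check of that equivalence on toy logits (values are made up for illustration):

import numpy as np

logits = np.array([[2.0, 0.5],   # scores for "next sentence" (0) vs "random sentence" (1)
                   [0.1, 1.2]])
labels = np.array([0, 1])

log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))  # log_softmax
one_hot = np.eye(2)[labels]
per_example_loss = -(one_hot * log_probs).sum(axis=-1)
loss = per_example_loss.mean()
print(per_example_loss, loss)  # matches what sparse softmax cross-entropy would give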
Example #3
def token_generator_gumbel_normal(config, input_tensor, output_weights,
                                  input_ids, input_ori_ids, input_mask,
                                  **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        # 	projection_width = config.hidden_size
        # else:
        # 	projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        logits_tempered = tf.nn.log_softmax(logits, axis=-1)

        # width=config.vocab_size
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        num_train_steps = kargs.get('num_train_steps', None)
        if num_train_steps and kargs.get('gumbel_anneal',
                                         "anneal") == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                kargs.get("num_train_steps", 10000),
                end_learning_rate=0.1,
                power=1.0,
                cycle=False)
        elif kargs.get('gumbel_anneal', "anneal") == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # batch x seq x dim
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
            if config.get('gen_sample', 1) > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******"
                )
                annealed_temp = tf.expand_dims(annealed_temp, -1)
        else:
            annealed_temp = 1.0
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))

        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        sampled_logprob_temp, sampled_logprob = gumbel_softmax(
            flat_logits_tempered,
            temperature=annealed_temp,
            samples=config.get('gen_sample', 1))

        # argmax on config.vocab_size which is always axis=1
        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        # argmax(logits + gumbel_samples) samples from a categorical distribution
        if kargs.get('sampled_prob_id', True):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp,
                                                   axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id

        # straight-through gumbel softmax estimator
        if kargs.get("straight_through", True):
            tf.logging.info(
                "****** apply straight_through_estimator without grl *******")
            sampled_id = tf.stop_gradient(sampled_hard_id -
                                          sampled_logprob_temp) + (
                                              sampled_logprob_temp)
        else:
            tf.logging.info(
                "****** apply gumbel-softmax probs without grl *******")
            sampled_id = flip_gradient(sampled_logprob_temp)

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(
                sampled_binary_mask)  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids,
                        tf.int32)  # 0 for original and 1 for replace
            )

        label_diff_ids = tf.cast(label_diff_ids, tf.float32)

        label_diff_ids = tf.expand_dims(label_diff_ids,
                                        axis=[-1])  # batch x seq x 1
        input_ori_ids = tf.one_hot(input_ori_ids,
                                   config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
        else:
            sampled_input_id = tf.reshape(sampled_id, [
                batch_size, seq_length, config.vocab_size,
                config.get('gen_sample', 1)
            ])
            label_diff_ids = tf.expand_dims(label_diff_ids,
                                            axis=-1)  # batch x seq x 1
            input_ori_ids = tf.expand_dims(input_ori_ids,
                                           axis=-1)  # batch x seq x vocab x 1

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - input_ori_ids) * label_diff_ids

        tf.logging.info("====generator use_tpu %s ====",
                        str(kargs.get('use_tpu', True)))
        if not kargs.get('use_tpu', True):
            tf.logging.info("====logging generator loss ====")
            sampled_not_equal_id = tf.not_equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(
                                            input_mask, tf.float32)
            sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        return sampled_input_id
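gumbel_softmax and sample_gumbel are helpers from the surrounding repository and are not shown in these examples. As a rough sketch of what a standard single-sample Gumbel-softmax sampler looks like (the names, signature, and two-value return convention here are assumptions, not the repo's actual helper, which also supports multiple samples):

import tensorflow as tf

def sample_gumbel_sketch(shape, eps=1e-10):
    # Standard Gumbel noise: -log(-log(U)) with U ~ Uniform(0, 1).
    u = tf.random_uniform(shape, minval=0.0, maxval=1.0)
    return -tf.log(-tf.log(u + eps) + eps)

def gumbel_softmax_sketch(logits, temperature=1.0):
    # Perturb the logits with Gumbel noise; the softmax of the tempered result is
    # the relaxed (differentiable) sample, while argmax of the perturbed logits
    # recovers a hard categorical sample.
    perturbed = logits + sample_gumbel_sketch(tf.shape(logits))
    soft_sample = tf.nn.softmax(perturbed / temperature, axis=-1)
    return soft_sample, perturbed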
Example #4
def token_generator_gumbel(config, input_tensor, output_weights, input_ids,
                           input_ori_ids, input_mask, **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        # 	projection_width = config.hidden_size
        # else:
        # 	projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        # it seems there is no need to normalize the logits here
        logits_tempered = logits  #tf.nn.log_softmax(logits, axis=-1)

        # width=config.vocab_size
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        num_train_steps = kargs.get('num_train_steps', None)
        if num_train_steps and kargs.get('gumbel_anneal',
                                         "anneal") == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            temperature_warmup_steps = int(num_train_steps) * 0.1
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                temperature_warmup_steps,
                end_learning_rate=0.01,
                power=1.0,
                cycle=False)
            gumbel_samples = None
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('annealed_temp', annealed_temp)
        elif kargs.get('gumbel_anneal', "anneal") == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # batch x seq x dim
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('softplus temperature',
                                  tf.reduce_mean(annealed_temp))
            if config.get('gen_sample', 1) > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******"
                )
                annealed_temp = tf.expand_dims(annealed_temp, -1)
            gumbel_samples = None
        elif kargs.get('gumbel_anneal', 'vqvae') == 'vqvae':
            temperature_warmup_steps = kargs.get("num_train_steps",
                                                 10000) * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(kargs.get("num_train_steps", 10000) * 0.1))
            steps = temperature_warmup_steps
            gumbel_samples = sample_gumbel(bert_utils.get_shape_list(
                flat_logits_tempered, expected_rank=2),
                                           samples=config.get('gen_sample', 1))
            gumbel_samples *= inverse_exp_decay(steps // 5) * 0.5
            annealed_temp_decay = 1.01 - inverse_lin_decay(
                steps)  # minimum temperature is set 0.2
            annealed_temp = tf.cond(
                tf.less(tf.random_uniform([]), 0.9),
                lambda: annealed_temp_decay, lambda: tf.random_uniform(
                    [], minval=0.5, maxval=1.0))  # 10% step for
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method ******* "
            )
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif kargs.get('gumbel_anneal', 'vqvae_v1') == 'vqvae_v1':
            temperature_warmup_steps = kargs.get("num_train_steps",
                                                 10000) * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(kargs.get("num_train_steps", 10000) * 0.1))
            steps = temperature_warmup_steps
            gumbel_samples = sample_gumbel(bert_utils.get_shape_list(
                flat_logits_tempered, expected_rank=2),
                                           samples=config.get('gen_sample', 1))
            # gumbel_samples *= inverse_exp_decay(steps)
            annealed_temp_decay = 1.01 - inverse_exp_decay(
                kargs.get("num_train_steps",
                          10000))  # minimum temperature is set 0.2
            annealed_temp = annealed_temp_decay
            # annealed_temp = tf.cond(
            # 			tf.less(tf.random_uniform([]), 0.95), lambda: annealed_temp_decay,
            # 			lambda: tf.random_uniform([], minval=0.5, maxval=1.0)) # 10% step for
            tf.logging.info(
                "****** apply sel-gan gumbel-softmax temperature annealing method ******* "
            )
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif kargs.get('gumbel_anneal', 'vqvae_v2') == 'vqvae_v2':
            temperature_warmup_steps = kargs.get("num_train_steps",
                                                 10000) * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(kargs.get("num_train_steps", 10000) * 0.1))
            steps = temperature_warmup_steps
            gumbel_samples = sample_gumbel(bert_utils.get_shape_list(
                flat_logits_tempered, expected_rank=2),
                                           samples=config.get('gen_sample', 1))
            # gumbel_samples *= inverse_exp_decay(steps)
            annealed_temp_decay = inverse_temp_exp_decay(
                kargs.get("num_train_steps", 10000),
                kargs.get("max_temp", 100))
            annealed_temp = annealed_temp_decay
            # annealed_temp = tf.cond(
            # 			tf.less(tf.random_uniform([]), 0.95), lambda: annealed_temp_decay,
            # 			lambda: tf.random_uniform([], minval=0.5, maxval=1.0)) # 10% step for
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax temperature annealing method ******* "
            )
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax num_train_steps:%s annealing method, temp:%s ******* ",
                str(kargs.get("num_train_steps", 10000)),
                str(kargs.get("max_temp", 100)))
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        else:
            annealed_temp = 0.01
            gumbel_samples = None
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('gumbel_temperature', annealed_temp)

        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        if kargs.get('stable_gradient', True):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=config.get('gen_sample', 1),
                greedy=kargs.get("greedy", False))
            tf.logging.info(
                "****** apply normal derivate for gradient calculation *******"
            )
        else:
            sampled_logprob_temp, sampled_logprob = gumbel_softmax_custom_grad(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=config.get('gen_sample', 1))
            tf.logging.info(
                "****** apply log deriviate for stable gradient calculation *******"
            )

        # argmax on config.vocab_size which is always axis=1
        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        # argmax(logits + gumbel_samples) samples from a categorical distribution
        if kargs.get('sampled_prob_id', False):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp,
                                                   axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id

        # straight-through gumbel softmax estimator
        if kargs.get('if_flip_grad', True):
            tf.logging.info("****** apply gradient flipping *******")
            sampled_logprob_temp_1 = flip_gradient(sampled_logprob_temp)
        else:
            tf.logging.info("****** not apply gradient flipping *******")
            sampled_logprob_temp_1 = sampled_logprob_temp
        if kargs.get("straight_through", True):
            tf.logging.info("****** apply straight_through_estimator *******")
            sampled_id = tf.stop_gradient(sampled_hard_id -
                                          sampled_logprob_temp) + (
                                              sampled_logprob_temp_1)
        else:
            tf.logging.info("****** apply gumbel-softmax probs *******")
            sampled_id = sampled_logprob_temp_1

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(
                sampled_binary_mask)  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids,
                        tf.int32)  # 0 for original and 1 for replace
            )

        label_diff_ids = tf.cast(label_diff_ids, tf.float32)

        label_diff_ids = tf.expand_dims(label_diff_ids,
                                        axis=[-1])  # batch x seq x 1
        input_ori_ids_1 = input_ori_ids
        input_ori_ids = tf.one_hot(input_ori_ids,
                                   config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                unk_mask = tf.cast(tf.math.equal(input_ori_ids_1, 100),
                                   tf.float32)  # not replace unk
                cls_mask = tf.cast(tf.math.equal(input_ori_ids_1, 101),
                                   tf.float32)  # not replace cls
                sep_mask = tf.cast(tf.math.equal(input_ori_ids_1, 102),
                                   tf.float32)  # not replace sep
                unsampled_mask = (1 -
                                  (unk_mask + cls_mask + sep_mask)) * tf.cast(
                                      input_mask, tf.float32)
                unsampled_mask = tf.expand_dims(unsampled_mask,
                                                axis=[-1])  # batch x seq x 1
                ori_input_mask = tf.expand_dims(input_mask,
                                                axis=[-1])  # batch x seq x 1

                tf.logging.info("****** all mask sample *******")
                sampled_input_id = unsampled_mask * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - unsampled_mask) * tf.cast(
                        ori_input_mask, tf.float32) * tf.cast(
                            input_ori_ids, tf.float32)
        else:
            sampled_input_id = tf.reshape(sampled_id, [
                batch_size, seq_length, config.vocab_size,
                config.get('gen_sample', 1)
            ])
            label_diff_ids = tf.expand_dims(label_diff_ids,
                                            axis=-1)  # batch x seq x 1
            input_ori_ids = tf.expand_dims(input_ori_ids,
                                           axis=-1)  # batch x seq x vocab x 1

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - input_ori_ids) * label_diff_ids

        tf.logging.info("====generator use_tpu %s ====",
                        str(kargs.get('use_tpu', True)))
        if not kargs.get('use_tpu', True):
            tf.logging.info("====logging generator loss ====")
            sampled_not_equal_id = tf.not_equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))
            sampled_equal_id = tf.equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))

            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(
                                            input_mask, tf.float32)

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
                label_diff_ids_my = tf.cast(label_diff_ids, tf.float32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                sampled_equal = tf.cast(
                    sampled_equal_id, tf.float32) * tf.cast(
                        tf.squeeze(unsampled_mask, axis=-1), tf.float32)
                tf.summary.scalar('generator_equal_sample_acc',
                                  tf.reduce_sum(sampled_equal))
                label_diff_ids_my = tf.cast(unsampled_mask, tf.float32)
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                sampled_equal = tf.reduce_sum(sampled_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)
            tf.summary.scalar("generator_valid_token",
                              tf.reduce_sum(input_mask))

            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp,
                                                   axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
            sampled_hard_id = tf.cast(sampled_hard_id, tf.float32)
            sampled_hard_id = tf.reshape(
                sampled_hard_id, [batch_size, seq_length, config.vocab_size])

            sampled_soft_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            sampled_hard_id *= label_diff_ids_my
            sampled_soft_id *= label_diff_ids_my

            hard_soft_bias = tf.reduce_sum(
                tf.sqrt(
                    tf.reduce_sum(tf.pow(sampled_hard_id - sampled_soft_id, 2),
                                  axis=-1))) / (1e-10 + tf.reduce_sum(
                                      tf.cast(label_diff_ids_my, tf.float32)))
            tf.summary.scalar('soft_hard_bias', hard_soft_bias)

        return sampled_input_id
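The straight-through trick used above, tf.stop_gradient(sampled_hard_id - sampled_logprob_temp) + sampled_logprob_temp, makes the forward pass emit the hard one-hot sample while the backward pass differentiates only through the soft probabilities. A minimal self-contained sketch of that estimator (the function name and explicit vocab_size argument are illustrative):

import tensorflow as tf

def straight_through_sketch(soft_probs, vocab_size):
    # Forward value: hard one-hot argmax of the soft sample.
    hard = tf.one_hot(tf.argmax(soft_probs, axis=-1), vocab_size,
                      dtype=soft_probs.dtype)
    # stop_gradient removes the (hard - soft) term from the backward graph, so the
    # gradient of the returned tensor equals the gradient of soft_probs.
    return tf.stop_gradient(hard - soft_probs) + soft_probs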
Example #5
def classifier(config, seq_output, input_ids, sampled_ids, input_mask,
               num_labels, dropout_prob, **kargs):
    """
	input_ids: original input ids
	sampled_ids: generated fake ids
	"""
    output_layer = seq_output
    hidden_size = output_layer.shape[-1].value

    unk_mask = tf.cast(tf.math.equal(input_ids, 100),
                       tf.float32)  # not replace unk
    cls_mask = tf.cast(tf.math.equal(input_ids, 101),
                       tf.float32)  # not replace cls
    sep_mask = tf.cast(tf.math.equal(input_ids, 102),
                       tf.float32)  # not replace sep

    none_replace_mask = unk_mask + cls_mask + sep_mask

    input_mask = tf.cast(input_mask, tf.int32)
    input_mask *= tf.cast(
        1 - none_replace_mask,
        tf.int32)  # cls, unk, sep are not considered as replace or original

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())

    if config.get('ln_type', 'postln') == 'preln':
        output_layer = albert_modules.layer_norm(output_layer)
        print('====preln transformer====')
    elif config.get('ln_type', 'postln') == 'postln':
        output_layer = output_layer
        print('====postln transformer====')
    else:
        output_layer = output_layer
        print('====no layer_norm====')

    output_layer = tf.nn.dropout(output_layer, keep_prob=1 - dropout_prob)

    logits = tf.einsum("abc,dc->abd", output_layer, output_weights)
    logits = tf.nn.bias_add(logits, output_bias)  # batch x seq_length x 2

    input_ids = tf.cast(input_ids, tf.int32)

    input_shape_list = bert_utils.get_shape_list(sampled_ids,
                                                 expected_rank=[2, 3])
    if len(input_shape_list) == 3:
        tmp_sampled_ids = tf.argmax(sampled_ids,
                                    axis=-1)  # batch x seq x vocab
        tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
        tf.logging.info("****** gumbel 3-D sampled_ids *******")
    elif len(input_shape_list) == 2:
        tmp_sampled_ids = sampled_ids
        tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
        tf.logging.info("****** normal 2-D sampled_ids *******")

    ori_sampled_ids = kargs.get('ori_sampled_ids', None)
    if ori_sampled_ids is not None:
        input_shape_list = bert_utils.get_shape_list(ori_sampled_ids,
                                                     expected_rank=[2, 3])
        if len(input_shape_list) == 3:
            tmp_ori_sampled_ids = tf.argmax(ori_sampled_ids,
                                            axis=-1)  # batch x seq x vocab
            tmp_ori_sampled_ids = tf.cast(tmp_ori_sampled_ids, tf.int32)
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif len(input_shape_list) == 2:
            tmp_ori_sampled_ids = tf.cast(ori_sampled_ids, tf.int32)
            tf.logging.info("****** normal 2-D sampled_ids *******")

        masked_not_equal_mask = tf.cast(
            tf.not_equal(input_ids, tmp_ori_sampled_ids), tf.int32)
        masked_not_equal_mask *= tf.cast(input_mask, tf.int32)
    else:
        masked_not_equal_mask = None
    if masked_not_equal_mask is not None:
        tf.logging.info(
            "****** loss mask using masked token mask for masked tokens *******"
        )
        loss_mask = masked_not_equal_mask
    else:
        tf.logging.info(
            "****** loss mask using input_mask for all tokens *******")
        loss_mask = input_mask

    # original:0, replace:1
    not_equal_label_ids = tf.cast(tf.not_equal(input_ids, tmp_sampled_ids),
                                  tf.int32)
    not_equal_label_ids *= tf.cast(input_mask, tf.int32)

    if kargs.get('loss', 'cross_entropy') == 'cross_entropy':
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=tf.stop_gradient(not_equal_label_ids))
    elif kargs.get('loss', 'cross_entropy') == 'focal_loss':
        input_shape_list = bert_utils.get_shape_list(input_ids,
                                                     expected_rank=2)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        not_equal_label_ids_ = tf.reshape(not_equal_label_ids,
                                          [batch_size * seq_length])
        logits_ = tf.reshape(logits, [batch_size * seq_length, -1])
        per_example_loss, _ = loss_utils.focal_loss_binary_v2(
            config, logits_, not_equal_label_ids_)
        per_example_loss = tf.reshape(per_example_loss,
                                      [batch_size, seq_length])

    # loss = per_example_loss * tf.cast(loss_mask, tf.float32)
    # loss = tf.reduce_sum(loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))

    equal_label_ids = (1 - tf.cast(not_equal_label_ids, tf.float32)) * tf.cast(
        loss_mask, tf.float32)
    equal_loss = tf.reduce_sum(per_example_loss * equal_label_ids)

    equal_loss_output = equal_loss / (1e-10 + tf.reduce_sum(equal_label_ids))

    not_equal_loss = tf.reduce_sum(
        per_example_loss *
        tf.cast(not_equal_label_ids, tf.float32))  # not equal:1, equal:0
    not_equal_loss_output = not_equal_loss / (
        1e-10 + tf.reduce_sum(tf.cast(not_equal_label_ids, tf.float32)))

    loss = (equal_loss + not_equal_loss) / (
        1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))

    tf.logging.info("====discriminator classifier use_tpu %s ====",
                    str(kargs.get('use_tpu', True)))
    if not kargs.get('use_tpu', True):
        tf.logging.info("====logging discriminator loss ====")
        tf.summary.scalar('mask_based_loss', loss)

        tf.summary.scalar(
            'equal_loss', equal_loss /
            (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32))))

        tf.summary.scalar(
            'not_equal_loss', not_equal_loss /
            (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32))))

        tf.summary.scalar(
            'loss_decomposition', loss - (equal_loss + not_equal_loss) /
            (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32))))

    return (loss, logits, per_example_loss)
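This discriminator excludes [UNK], [CLS], and [SEP] from the loss by building a none_replace_mask from the hard-coded ids 100/101/102, which assumes the default BERT WordPiece vocabulary. A small NumPy sketch of that masking on a toy sequence (the ids and shapes are made up):

import numpy as np

input_ids = np.array([[101, 1996, 4937, 102, 0]])   # [CLS] the cat [SEP] [PAD]
input_mask = np.array([[1, 1, 1, 1, 0]])

none_replace_mask = np.isin(input_ids, [100, 101, 102]).astype(np.int32)
loss_mask = input_mask * (1 - none_replace_mask)
print(loss_mask)  # [[0 1 1 0 0]] -> only ordinary tokens contribute to the loss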
Example #6
def get_masked_lm_output(config, input_tensor, output_weights, positions,
							label_ids, label_weights,
							**kargs):

	"""Get loss and log probs for the masked LM."""
	reuse = kargs.get('reuse', False)
	embedding_projection = kargs.get('embedding_projection', None)
	input_tensor = tf.cast(input_tensor, tf.float32)
	positions = tf.cast(positions, tf.int32)
	label_ids = tf.cast(label_ids, tf.int32)
	label_weights = tf.cast(label_weights, tf.float32)

	# flatten masked lm ids with positions
	input_tensor = bert_utils.gather_indexes(input_tensor, positions)

	scope = kargs.get('scope', None)
	if scope:
		scope = scope + '/' + 'cls/predictions'
	else:
		scope = 'cls/predictions'

	tf.logging.info("**** mlm scope **** %s", str(scope))

	# with tf.variable_scope("cls/predictions", reuse=reuse):
	with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
		# We apply one more non-linear transformation before the output layer.
		# This matrix is not used after pre-training.
		if config.get('ln_type', 'postln') == 'preln':
			input_tensor = albert_modules.layer_norm(input_tensor)
		elif config.get('ln_type', 'postln') == 'postln':
			input_tensor = input_tensor
		else:
			input_tensor = input_tensor

		if config.get("embedding", "factorized") == "factorized":
			projection_width = config.hidden_size
		else:
			projection_width = config.embedding_size

		with tf.variable_scope("transform"):
			input_tensor = tf.layers.dense(
					input_tensor,
					units=projection_width,
					activation=albert_modules.get_activation(config.hidden_act),
					kernel_initializer=albert_modules.create_initializer(
							config.initializer_range))
			if config.get('ln_type', 'postln') == 'preln':
				input_tensor = input_tensor
			elif config.get('ln_type', 'postln') == 'postln':
				input_tensor = albert_modules.layer_norm(input_tensor)
			else:
				input_tensor = albert_modules.layer_norm(input_tensor)

		if embedding_projection is not None:
			input_tensor = tf.matmul(input_tensor, 
								embedding_projection,
								transpose_b=True)
		else:
			print("==no need for embedding projection==")
			input_tensor = input_tensor

		# The output weights are the same as the input embeddings, but there is
		# an output-only bias for each token.
		output_bias = tf.get_variable(
				"output_bias",
				shape=[config.vocab_size],
				initializer=tf.zeros_initializer())
		logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
		# logits = tf.multiply(logits,
		# 						1.0 / math.sqrt(float(config.hidden_size)))
		# logits *= 2
		
		logits = tf.nn.bias_add(logits, output_bias)
		log_probs = tf.nn.log_softmax(logits, axis=-1)

		label_ids = tf.reshape(label_ids, [-1])
		label_weights = tf.reshape(label_weights, [-1])

		# one_hot_labels = tf.one_hot(
		# 		label_ids, depth=config.vocab_size, dtype=tf.float32)

		per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
													labels=tf.stop_gradient(label_ids),
													logits=logits)
		# per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])

		numerator = tf.reduce_sum(label_weights * per_example_loss)
		denominator = tf.reduce_sum(label_weights) + 1e-5

		# The `positions` tensor might be zero-padded (if the sequence is too
		# short to have the maximum number of predictions). The `label_weights`
		# tensor has a value of 1.0 for every real prediction and 0.0 for the
		# padding predictions.
		# per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
		# numerator = tf.reduce_sum(label_weights * per_example_loss)
		# denominator = tf.reduce_sum(label_weights) + 1e-5
		loss = numerator / denominator

	return (loss, per_example_loss, log_probs, label_weights)
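bert_utils.gather_indexes, which pulls out the hidden states at the masked positions, is not shown in these examples. A sketch of the usual BERT-style implementation it presumably follows (an assumption based on the reference BERT code, not the repo's own definition):

import tensorflow as tf

def gather_indexes_sketch(sequence_tensor, positions):
    # sequence_tensor: [batch, seq, width]; positions: [batch, num_predictions] (int32).
    shape = sequence_tensor.shape.as_list()
    batch_size, seq_length, width = shape[0], shape[1], shape[2]
    # Offset each row's positions into the flattened [batch * seq, width] tensor.
    flat_offsets = tf.reshape(
        tf.range(batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence = tf.reshape(sequence_tensor,
                               [batch_size * seq_length, width])
    return tf.gather(flat_sequence, flat_positions)  # [batch * num_predictions, width]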
Example #7
def transformer_cell(input_tensor,
                     attention_mask=None,
                     hidden_size=768,
                     num_hidden_layers=12,
                     num_attention_heads=12,
                     intermediate_size=3072,
                     intermediate_act_fn=gelu,
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     initializer_range=0.02,
                     do_return_all_layers=False,
                     shared_type=None,
                     adapter_fn=None):

    # batch_size, seq_length, attention_head_size, all_attention_scores and
    # layer_idx are referenced below but not defined in the snippet; derive or
    # assume them here.
    input_shape = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    attention_head_size = int(hidden_size / num_attention_heads)
    all_attention_scores = []
    layer_idx = 0  # single shared cell, so a fixed layer index is assumed

    layer_input = bert_utils.reshape_to_matrix(input_tensor)
    with tf.variable_scope("layer_shared", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("attention", reuse=tf.AUTO_REUSE):
            attention_heads = []
            with tf.variable_scope("self"):
                [attention_head,
                 attention_scores] = albert_modules.attention_layer(
                     from_tensor=layer_input,
                     to_tensor=layer_input,
                     attention_mask=attention_mask,
                     num_attention_heads=num_attention_heads,
                     size_per_head=attention_head_size,
                     attention_probs_dropout_prob=attention_probs_dropout_prob,
                     initializer_range=initializer_range,
                     do_return_2d_tensor=True,
                     batch_size=batch_size,
                     from_seq_length=seq_length,
                     to_seq_length=seq_length)
                attention_heads.append(attention_head)
                all_attention_scores.append(attention_scores)

            attention_output = None
            if len(attention_heads) == 1:
                attention_output = attention_heads[0]
            else:
                # In the case where we have other sequences, we just concatenate
                # them to the self-attention head before the projection.
                attention_output = tf.concat(attention_heads, axis=-1)

            # Run a linear projection of `hidden_size` then add a residual
            # with `layer_input`.
            with tf.variable_scope("output", reuse=tf.AUTO_REUSE):
                attention_output = tf.layers.dense(
                    attention_output,
                    hidden_size,
                    kernel_initializer=albert_modules.create_initializer(
                        initializer_range))
                attention_output = albert_modules.dropout(
                    attention_output, hidden_dropout_prob)

                if adapter_fn:
                    attention_output = adapter_fn(attention_output,
                                                  layer_idx=layer_idx)

                attention_output = albert_modules.layer_norm(attention_output +
                                                             layer_input)

        # The activation is only applied to the "intermediate" hidden layer.
        with tf.variable_scope('intermediate', reuse=tf.AUTO_REUSE):
            intermediate_output = tf.layers.dense(
                attention_output,
                intermediate_size,
                activation=intermediate_act_fn,
                kernel_initializer=albert_modules.create_initializer(
                    initializer_range))

        # Down-project back to `hidden_size` then add the residual.
        with tf.variable_scope('output', reuse=tf.AUTO_REUSE):
            layer_output = tf.layers.dense(
                intermediate_output,
                hidden_size,
                kernel_initializer=albert_modules.create_initializer(
                    initializer_range))
            layer_output = albert_modules.dropout(layer_output,
                                                  hidden_dropout_prob)

        if adapter_fn:
            layer_output = adapter_fn(layer_output, layer_idx=layer_idx)

        layer_output = albert_modules.layer_norm(layer_output +
                                                 attention_output)

    # assumed return convention: the snippet ended before a return statement
    return layer_output, all_attention_scores
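adapter_fn is passed in by the caller and is not defined in these examples. If it follows the common bottleneck-adapter pattern (down-project, nonlinearity, up-project, residual), a minimal sketch could look like the following; the scope name, bottleneck size, and ReLU choice are assumptions:

import tensorflow as tf

def adapter_sketch(hidden_states, layer_idx, bottleneck_size=64):
    hidden_size = hidden_states.shape[-1].value
    with tf.variable_scope("adapter_%d" % layer_idx, reuse=tf.AUTO_REUSE):
        down = tf.layers.dense(hidden_states, bottleneck_size,
                               activation=tf.nn.relu)  # down-projection
        up = tf.layers.dense(down, hidden_size)        # up-projection
    return hidden_states + up                          # residual connection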
Example #8
def classifier(config,
               pooled_output,
               num_labels,
               labels,
               dropout_prob,
               ratio_weight=None,
               **kargs):

    output_layer = pooled_output

    hidden_size = output_layer.shape[-1].value

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())

    if config.get('ln_type', 'postln') == 'preln':
        output_layer = albert_modules.layer_norm(output_layer)
    elif config.get('ln_type', 'postln') == 'postln':
        output_layer = output_layer
    else:
        output_layer = output_layer

    output_layer = tf.nn.dropout(output_layer, keep_prob=1 - dropout_prob)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    if config.get("label_type", "single_label") == "single_label":
        if config.get("loss", "entropy") == "entropy":
            print("==standard cross entropy==")
            tf.logging.info("****** loss type ******* %s", "entropy")
            per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=tf.stop_gradient(labels))
        elif config.get("loss", "entropy") == "focal_loss":
            print("==multi_label focal loss==")
            per_example_loss, _ = loss_utils.focal_loss_multi_v1(config,
                                                                 logits=logits,
                                                                 labels=labels)
        elif config.get("loss", "entropy") == "dmi_loss":
            tf.logging.info("****** loss type ******* %s", "dmi_loss")
            loss, per_example_loss = loss_utils.dmi_loss(config,
                                                         logits=logits,
                                                         labels=labels,
                                                         **kargs)

        try:
            per_example_loss = loss_utils.weighted_loss_ratio(
                config, per_example_loss, labels, ratio_weight)
            loss = tf.reduce_sum(per_example_loss)
            print(" == applying weighted loss == ")
        except:
            if config.get("loss", "entropy") in ["entropy", "focal_loss"]:
                loss = tf.reduce_mean(per_example_loss)
            elif config.get("loss", "entropy") == "dmi_loss":
                tf.logging.info(
                    "****** dmi loss need no further calculation ******* ")
                loss = loss

        if config.get("with_center_loss", "no") == "center_loss":
            print("==apply with center loss==")
            center_loss, _ = loss_utils.center_loss_v2(config,
                                                       features=pooled_output,
                                                       labels=labels)
            loss += center_loss * config.get("center_loss_coef", 1e-3)

        return (loss, per_example_loss, logits)
    elif config.get("label_type", "single_label") == "multi_label":
        # logits = tf.log_sigmoid(logits)
        per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits,
            labels=tf.stop_gradient(tf.cast(labels, tf.float32)))
        per_example_loss = tf.reduce_sum(per_example_loss, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, per_example_loss, logits)
    else:
        raise NotImplementedError()
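The multi_label branch above treats each label as an independent binary decision: a sigmoid cross-entropy per label, summed over labels, averaged over examples. A NumPy sketch of the same computation on toy values:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

logits = np.array([[1.2, -0.3, 0.4],
                   [-2.0, 0.8, 0.1]])
labels = np.array([[1, 0, 1],
                   [0, 1, 0]], dtype=np.float64)

per_label = -(labels * np.log(sigmoid(logits)) +
              (1 - labels) * np.log(1 - sigmoid(logits)))
per_example_loss = per_label.sum(axis=-1)   # sum over labels
loss = per_example_loss.mean()              # mean over examples
print(per_example_loss, loss)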
Example #9
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        # 	projection_width = config.hidden_size
        # else:
        # 	projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        if not kargs.get("apply_valid_vocab", False):
            logits = logits
            tf.logging.info("****** normal logits *******")
        elif kargs.get("apply_valid_vocab", False) == 'topk':
            prob, _ = top_k_softmax(logits, kargs.get('topk', 10))
            logits = tf.log(prob + 1e-10)
            tf.logging.info("****** topk logits *******")
        else:
            invalid_size = kargs.get("invalid_size", 106)
            invalid_mask = tf.cast(
                tf.ones((1, invalid_size)) * (-10000), tf.float32)
            valid_mask = tf.cast(
                tf.zeros((1, config.vocab_size - invalid_size)), tf.float32)
            invalid_mask = tf.concat([invalid_mask, valid_mask], axis=-1)
            invalid_mask = tf.expand_dims(
                invalid_mask, axis=1)  # 1 x 1 x vocab, broadcast over batch x seq
            logits += tf.cast(invalid_mask, tf.float32)
            tf.logging.info(
                "****** only valid logits ******* , invalid size: %s",
                str(invalid_size))

        logits_tempered = tf.nn.log_softmax(logits /
                                            config.get("temperature", 1.0))

        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        # flat_logits_tempered_topk = top_k_logits(flat_logits_tempered, int(config.vocab_size/2
        if not kargs.get("greedy", False):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=1.0,
                samples=config.get('gen_sample', 1),
                greedy=kargs.get("greedy", False))

            samples = tf.argmax(sampled_logprob, axis=1)  # batch x seq
            tf.logging.info("****** normal sample *******")
        else:
            samples = tf.argmax(flat_logits_tempered, axis=-1)
            tf.logging.info("****** greedy sample *******")

        # samples = tf.multinomial(flat_logits_tempered,
        # 						num_samples=config.get('gen_sample', 1),
        # 						output_dtype=tf.int32)

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = sampled_binary_mask  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids,
                        tf.int32)  # 0 for original and 1 for replace
            )
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        if not kargs.get('use_tpu', True):
            tf.summary.scalar(
                'label_diff_ids',
                tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32))
                / tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                input_ori_ids_1 = input_ori_ids
                unk_mask = tf.cast(tf.math.equal(input_ori_ids_1, 100),
                                   tf.float32)  # not replace unk
                cls_mask = tf.cast(tf.math.equal(input_ori_ids_1, 101),
                                   tf.float32)  # not replace cls
                sep_mask = tf.cast(tf.math.equal(input_ori_ids_1, 102),
                                   tf.float32)  # not replace sep
                unsampled_mask = (1 -
                                  (unk_mask + cls_mask + sep_mask)) * tf.cast(
                                      input_mask, tf.float32)
                # unsampled_mask = tf.expand_dims(unsampled_mask, axis=[-1]) # batch x seq x 1
                tf.logging.info("****** all mask sample *******")
                sampled_input_id = unsampled_mask * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - unsampled_mask) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length,
                          config.get('gen_sample', 1)])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                # batch x seq_length x 1
                label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
                label_diff_ids = tf.einsum(
                    'abc,cd->abd', label_diff_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                # batch x seq_length x 1
                input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)
                # cast before einsum so both operands are float32
                input_ori_ids = tf.cast(input_ori_ids, tf.float32)
                input_ori_ids = tf.einsum(
                    'abc,cd->abd', input_ori_ids,
                    tf.ones((1, config.get('gen_sample', 1))))

                # keep sampled ids at replaced positions, restore original ids elsewhere
                sampled_input_id = label_diff_ids * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - label_diff_ids) * input_ori_ids
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

                input_mask = tf.expand_dims(input_mask, axis=-1)
                # cast before einsum so both operands are float32
                input_mask = tf.cast(input_mask, tf.float32)
                input_mask = tf.einsum(
                    'abc,cd->abd', input_mask,
                    tf.ones((1, config.get('gen_sample', 1))))

        if not kargs.get('use_tpu', True):

            sampled_not_equal_id = tf.not_equal(
                tf.cast(sampled_input_id, tf.int32),
                tf.cast(input_ori_ids, tf.int32))
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(
                                            input_mask, tf.float32)

            sampled_equal_id = tf.equal(tf.cast(sampled_input_id, tf.int32),
                                        tf.cast(input_ori_ids, tf.int32))

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                sampled_equal = tf.cast(sampled_equal_id,
                                        tf.float32) * tf.cast(
                                            unsampled_mask, tf.float32)
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                sampled_equal = tf.reduce_sum(sampled_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                tf.summary.scalar('generator_equal_sample_acc', sampled_equal)
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        # sampled_not_equal_id = tf.not_equal(
        # 		tf.cast(sampled_input_id, tf.int32),
        # 		tf.cast(input_ori_ids, tf.int32)
        # )
        # sampled_not_equal = tf.cast(sampled_not_equal_id, tf.float32) * tf.cast(input_mask, tf.float32)
        # sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))

        # if not kargs.get('use_tpu', True):
        # 	tf.summary.scalar('generator_sample_acc',
        # 					sampled_not_equal)

        return sampled_input_id
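For reference, a minimal NumPy sketch of the only_mask mixing rule above (toy ids, hypothetical values, nothing taken from the repository): sampled ids are kept only where label_diff_ids marks a masked/replaced position, and the original ids are restored everywhere else.

import numpy as np

input_ori_ids  = np.array([[101, 2023, 2003, 103, 102]])            # 103 marks the masked slot
sampled_ids    = np.array([[1999, 2051, 2054, 2307, 2028]])         # generator samples per position
label_diff_ids = np.array([[0., 0., 0., 1., 0.]], dtype=np.float32) # 1 = masked/replaced position

mixed = label_diff_ids * sampled_ids + (1. - label_diff_ids) * input_ori_ids
print(mixed.astype(np.int32))  # [[ 101 2023 2003 2307  102]]: only the masked slot takes the sample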
Example #10
def seq_mask_masked_lm_output(config, input_tensor, output_weights,
							input_mask, input_ori_ids, input_ids, 
							sampled_binary_mask, **kargs):

	input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
	batch_size = input_shape_list[0]
	seq_length = input_shape_list[1]
	hidden_dims = input_shape_list[2]

	embedding_projection = kargs.get('embedding_projection', None)

	scope = kargs.get('scope', None)
	if scope:
		scope = scope + '/' + 'cls/predictions'
	else:
		scope = 'cls/predictions'

	tf.logging.info("**** mlm generator scope **** %s", str(scope))

	# with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
	with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
		if config.get('ln_type', 'postln') == 'preln':
			input_tensor = albert_modules.layer_norm(input_tensor)
		elif config.get('ln_type', 'postln') == 'postln':
			input_tensor = input_tensor
		else:
			input_tensor = input_tensor

		if config.get("embedding", "factorized") == "factorized":
			projection_width = config.hidden_size
		else:
			projection_width = config.embedding_size

		with tf.variable_scope("transform"):
			input_tensor = tf.layers.dense(
					input_tensor,
					units=projection_width,
					activation=albert_modules.get_activation(config.hidden_act),
					kernel_initializer=albert_modules.create_initializer(
							config.initializer_range))

			if config.get('ln_type', 'postln') == 'preln':
				input_tensor = input_tensor
			elif config.get('ln_type', 'postln') == 'postln':
				input_tensor = albert_modules.layer_norm(input_tensor)
			else:
				input_tensor = albert_modules.layer_norm(input_tensor)

		if embedding_projection is not None:
			# batch x seq x hidden, embedding x hidden
			print(input_tensor.get_shape(), embedding_projection.get_shape())
			input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection)
		else:
			print("==no need for embedding projection==")
			input_tensor = input_tensor

		output_bias = tf.get_variable(
				"output_bias",
				shape=[config.vocab_size],
				initializer=tf.zeros_initializer())
		# batch x seq x embedding
		logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
		logits = tf.nn.bias_add(logits, output_bias)

		"""
		if input_ori_ids[i] is random pertubated, sampled_binary_mask[i]=1
		"""
		sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32)
		per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
												logits=logits,
												labels=tf.stop_gradient(input_ori_ids),
												)
		per_example_loss *= sampled_binary_mask
		loss = tf.reduce_sum(per_example_loss) / (1e-10 + tf.reduce_sum(sampled_binary_mask))

		return (loss, per_example_loss, logits, sampled_binary_mask)
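As a sanity check on the masked-average loss above, a hedged NumPy sketch (toy shapes and values, not repository code): the per-position cross-entropy is zeroed wherever sampled_binary_mask is 0 and averaged over the number of perturbed positions, with the same 1e-10 guard against an all-zero mask.

import numpy as np

logits = np.random.randn(1, 4, 8).astype(np.float32)    # batch x seq x vocab (toy)
labels = np.array([[3, 1, 7, 0]])                        # original token ids
mask   = np.array([[1., 0., 1., 0.]], dtype=np.float32)  # 1 = position was perturbed

log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
per_example_loss = -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
loss = (per_example_loss * mask).sum() / (1e-10 + mask.sum())
print(loss)  # mean negative log-likelihood over the two perturbed positions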
Example #11
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        logits_tempered = logits / config.get("temperature", 1.0)

        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        flat_logits_tempered_topk = top_k_logits(flat_logits_tempered,
                                                 int(config.vocab_size / 2))

        samples = tf.multinomial(flat_logits_tempered_topk,
                                 num_samples=config.get('gen_sample', 1),
                                 output_dtype=tf.int32)

        label_diff_ids = tf.equal(tf.cast(input_ids, tf.int32),
                                  tf.cast(input_ori_ids, tf.int32))
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        tf.summary.scalar(
            'label_diff_ids',
            tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32)) /
            tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if kargs.get('mask_method', 'all') == 'only_mask':
                # label_diff_ids is 1 where input_ids equals input_ori_ids (unchanged positions)
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (1 - label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32) + label_diff_ids * tf.cast(
                        input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length,
                          config.get('gen_sample', 1)])
            if kargs.get('mask_method', 'all') == 'only_mask':
                # batch x seq_length x 1 -> batch x seq_length x gen_sample
                label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
                label_diff_ids = tf.einsum(
                    'abc,cd->abd', label_diff_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                # batch x seq_length x 1 -> batch x seq_length x gen_sample
                input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)
                # cast before einsum so both operands are float32
                input_ori_ids = tf.cast(input_ori_ids, tf.float32)
                input_ori_ids = tf.einsum(
                    'abc,cd->abd', input_ori_ids,
                    tf.ones((1, config.get('gen_sample', 1))))

                sampled_input_id = (1 - label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + input_ori_ids * label_diff_ids
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

        return sampled_input_id
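top_k_logits is defined elsewhere in this repository; assuming it keeps the k largest logits per row and pushes the rest to a large negative value, the sampling step above can be sketched in NumPy as follows (toy shapes, hypothetical helper name top_k_filter):

import numpy as np

def top_k_filter(logits, k):
    # keep the k largest logits per row, mask the rest to -1e10 (sketch, not the real top_k_logits)
    kth_largest = np.sort(logits, axis=-1)[:, -k][:, None]
    return np.where(logits < kth_largest, -1e10, logits)

rng = np.random.default_rng(0)
flat_logits = rng.normal(size=(3, 10))                   # (batch*seq) x vocab, temperature 1.0
filtered = top_k_filter(flat_logits, k=5)
probs = np.exp(filtered) / np.exp(filtered).sum(-1, keepdims=True)
samples = np.array([rng.choice(10, p=p) for p in probs]) # one sampled id per flattened position
print(samples.reshape(1, 3))                             # back to batch x seq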