Code Example #1
def global_discriminator_logits(config, input_tensor, reuse=None, **kargs):
    """Get loss and log probs for the next sentence prediction."""
    # Simple binary classification. Note that 0 is "next sentence" and 1 is
    # "random sentence". This weight matrix is not used after pre-training.

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/seq_global'
    else:
        scope = 'cls/seq_global'
    tf.logging.info("**** nsp scope **** %s", str(scope))

    # with tf.variable_scope("cls/seq_relationship", reuse=reuse):
    with tf.variable_scope(scope, reuse=reuse):
        output_weights = tf.get_variable(
            "output_weights",
            shape=[2, config.hidden_size],
            initializer=albert_modules.create_initializer(
                config.initializer_range))
        output_bias = tf.get_variable("output_bias",
                                      shape=[2],
                                      initializer=tf.zeros_initializer())

        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        return logits
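The head above is just a two-way linear classifier over a pooled feature. Below is a minimal standalone sketch of the same computation in plain TensorFlow 1.x; the `tf.truncated_normal_initializer` stand-in for `albert_modules.create_initializer` and the default sizes are assumptions, not part of the original repository.

import tensorflow as tf

def binary_discriminator_head_sketch(pooled, hidden_size=768, initializer_range=0.02):
    # pooled: [batch_size, hidden_size] sequence-level features
    with tf.variable_scope("cls/seq_global", reuse=tf.AUTO_REUSE):
        w = tf.get_variable(
            "output_weights", [2, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=initializer_range))
        b = tf.get_variable("output_bias", [2], initializer=tf.zeros_initializer())
        return tf.nn.bias_add(tf.matmul(pooled, w, transpose_b=True), b)  # [batch_size, 2]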
Code Example #2
def global_feature_discriminator(config, input_tensor, labels, reuse=None, **kargs):
	"""Get loss and log probs for the next sentence prediction."""
	# Simple binary classification. Note that 0 is "next sentence" and 1 is
	# "random sentence". This weight matrix is not used after pre-training.

	scope = kargs.get('scope', None)
	if scope:
		scope = scope + '/' + 'cls/seq_global'
	else:
		scope = 'cls/seq_global'
	tf.logging.info("**** nsp scope **** %s", str(scope))

	# with tf.variable_scope("cls/seq_relationship", reuse=reuse):
	with tf.variable_scope(scope, reuse=reuse):
		output_weights = tf.get_variable(
				"output_weights",
				shape=[2, config.hidden_size],
				initializer=albert_modules.create_initializer(config.initializer_range))
		output_bias = tf.get_variable(
				"output_bias", shape=[2], initializer=tf.zeros_initializer())

		logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
		logits = tf.nn.bias_add(logits, output_bias)
		log_probs = tf.nn.log_softmax(logits, axis=-1)
		labels = tf.reshape(labels, [-1])
		one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
		per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
		loss = tf.reduce_mean(per_example_loss)
		return (loss, per_example_loss, log_probs)
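The hand-rolled loss above (one-hot labels times log-softmax) is numerically equivalent to TensorFlow's built-in sparse softmax cross-entropy for integer labels. A short equivalence sketch, assuming TF 1.x and logits of shape [batch_size, 2]:

import tensorflow as tf

def discriminator_loss_sketch(logits, labels):
    # labels: int tensor with values in {0, 1}
    labels = tf.reshape(labels, [-1])
    # Same value as -reduce_sum(one_hot(labels) * log_softmax(logits), axis=-1)
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.reduce_mean(per_example_loss), per_example_loss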
Code Example #3
    def build_pooler(self, *args, **kargs):
        reuse = kargs["reuse"]
        layer_num = kargs.get("layer_num", -1)
        with tf.variable_scope(self.config.get("scope", "bert"), reuse=reuse):
            # self.sequence_output = self.all_encoder_layers[-1]
            self.sequence_output = self.get_encoder_layers(layer_num)

            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_size, seq_length, hidden_size] to a tensor of shape
            # [batch_size, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained
                first_token_tensor = tf.squeeze(self.sequence_output[:,
                                                                     0:1, :],
                                                axis=1)
                self.pooled_output = tf.layers.dense(
                    first_token_tensor,
                    self.config.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=albert_modules.create_initializer(
                        self.config.initializer_range))
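The pooler simply takes the hidden state of the first ([CLS]) token and passes it through a tanh dense layer. A condensed sketch of the same operation with plain TF 1.x initializers; the default sizes are placeholder assumptions.

import tensorflow as tf

def pool_first_token_sketch(sequence_output, hidden_size=768, initializer_range=0.02):
    # sequence_output: [batch_size, seq_length, hidden_size]
    first_token = sequence_output[:, 0, :]  # same as squeeze(x[:, 0:1, :], axis=1)
    return tf.layers.dense(
        first_token, hidden_size, activation=tf.tanh,
        kernel_initializer=tf.truncated_normal_initializer(stddev=initializer_range))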
Code Example #4
def token_generator_gumbel_normal(config, input_tensor, output_weights,
                                  input_ids, input_ori_ids, input_mask,
                                  **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        # 	projection_width = config.hidden_size
        # else:
        # 	projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        logits_tempered = tf.nn.log_softmax(logits, axis=-1)

        # width=config.vocab_size
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        num_train_steps = kargs.get('num_train_steps', None)
        if num_train_steps and kargs.get('gumbel_anneal',
                                         "anneal") == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                kargs.get("num_train_steps", 10000),
                end_learning_rate=0.1,
                power=1.0,
                cycle=False)
        elif kargs.get('gumbel_anneal', "anneal") == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # batch x seq x dim
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
            if config.get('gen_sample', 1) > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******"
                )
                annealed_temp = tf.expand_dims(annealed_temp, -1)
        else:
            annealed_temp = 1.0
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))

        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        sampled_logprob_temp, sampled_logprob = gumbel_softmax(
            flat_logits_tempered,
            temperature=annealed_temp,
            samples=config.get('gen_sample', 1))

        # argmax on config.vocab_size which is always axis=1
        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        # argmax(logits + gumbel_samples) samples from a categorical distribution
        if kargs.get('sampled_prob_id', True):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp,
                                                   axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id

        # straight-through gumbel softmax estimator
        if kargs.get("straight_through", True):
            tf.logging.info(
                "****** apply straight_through_estimator without grl *******")
            sampled_id = tf.stop_gradient(sampled_hard_id -
                                          sampled_logprob_temp) + (
                                              sampled_logprob_temp)
        else:
            tf.logging.info(
                "****** apply gumbel-softmax probs without grl *******")
            sampled_id = flip_gradient(sampled_logprob_temp)

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(
                sampled_binary_mask)  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids,
                        tf.int32)  # 0 for original and 1 for replace
            )

        label_diff_ids = tf.cast(label_diff_ids, tf.float32)

        label_diff_ids = tf.expand_dims(label_diff_ids,
                                        axis=[-1])  # batch x seq x 1
        input_ori_ids = tf.one_hot(input_ori_ids,
                                   config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
        else:
            sampled_input_id = tf.reshape(sampled_id, [
                batch_size, seq_length, config.vocab_size,
                config.get('gen_sample', 1)
            ])
            label_diff_ids = tf.expand_dims(label_diff_ids,
                                            axis=-1)  # batch x seq x 1
            input_ori_ids = tf.expand_dims(input_ori_ids,
                                           axis=-1)  # batch x seq x vocab x 1

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - input_ori_ids) * label_diff_ids

        tf.logging.info("====generator use_tpu %s ====",
                        str(kargs.get('use_tpu', True)))
        if not kargs.get('use_tpu', True):
            tf.logging.info("====logging generator loss ====")
            sampled_not_equal_id = tf.not_equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(
                                            input_mask, tf.float32)
            sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        return sampled_input_id
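The core of the generator above is the straight-through Gumbel-Softmax trick: the forward pass emits a (near) one-hot sample while gradients flow through the soft probabilities. A self-contained sketch of that estimator, assuming TF 1.x; the exact `gumbel_softmax` helper used in the repository may differ in details such as temperature handling and multi-sampling.

import tensorflow as tf

def st_gumbel_softmax_sketch(logits, temperature=1.0):
    # logits: [num_rows, vocab_size]
    u = tf.random_uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
    gumbel_noise = -tf.log(-tf.log(u))                             # Gumbel(0, 1) samples
    y_soft = tf.nn.softmax((logits + gumbel_noise) / temperature)  # differentiable sample
    y_hard = tf.one_hot(tf.argmax(y_soft, axis=-1), tf.shape(logits)[-1])
    # Forward pass returns y_hard; the gradient is taken w.r.t. y_soft only.
    return tf.stop_gradient(y_hard - y_soft) + y_soft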
Code Example #5
def token_generator_gumbel(config, input_tensor, output_weights, input_ids,
                           input_ori_ids, input_mask, **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        # 	projection_width = config.hidden_size
        # else:
        # 	projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        # no need to normalize the logits here; the Gumbel-Softmax applies its own softmax
        logits_tempered = logits  # tf.nn.log_softmax(logits, axis=-1)

        # width=config.vocab_size
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        num_train_steps = kargs.get('num_train_steps', None)
        if num_train_steps and kargs.get('gumbel_anneal',
                                         "anneal") == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            temperature_warmup_steps = int(num_train_steps) * 0.1
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                temperature_warmup_steps,
                end_learning_rate=0.01,
                power=1.0,
                cycle=False)
            gumbel_samples = None
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('annealed_temp', annealed_temp)
        elif kargs.get('gumbel_anneal', "anneal") == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # batch x seq x dim
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('softplus temperature',
                                  tf.reduce_mean(annealed_temp))
            if config.get('gen_sample', 1) > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******"
                )
                annealed_temp = tf.expand_dims(annealed_temp, -1)
            gumbel_samples = None
        elif kargs.get('gumbel_anneal', 'vqvae') == 'vqvae':
            temperature_warmup_steps = kargs.get("num_train_steps",
                                                 10000) * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(kargs.get("num_train_steps", 10000) * 0.1))
            steps = temperature_warmup_steps
            gumbel_samples = sample_gumbel(bert_utils.get_shape_list(
                flat_logits_tempered, expected_rank=2),
                                           samples=config.get('gen_sample', 1))
            gumbel_samples *= inverse_exp_decay(steps // 5) * 0.5
            annealed_temp_decay = 1.01 - inverse_lin_decay(
                steps)  # minimum temperature is set 0.2
            annealed_temp = tf.cond(
                tf.less(tf.random_uniform([]), 0.9),
                lambda: annealed_temp_decay, lambda: tf.random_uniform(
                    [], minval=0.5, maxval=1.0))  # ~10% of steps use a random temperature
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method ******* "
            )
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif kargs.get('gumbel_anneal', 'vqvae_v1') == 'vqvae_v1':
            temperature_warmup_steps = kargs.get("num_train_steps",
                                                 10000) * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(kargs.get("num_train_steps", 10000) * 0.1))
            steps = temperature_warmup_steps
            gumbel_samples = sample_gumbel(bert_utils.get_shape_list(
                flat_logits_tempered, expected_rank=2),
                                           samples=config.get('gen_sample', 1))
            # gumbel_samples *= inverse_exp_decay(steps)
            annealed_temp_decay = 1.01 - inverse_exp_decay(
                kargs.get("num_train_steps",
                          10000))  # minimum temperature is set 0.2
            annealed_temp = annealed_temp_decay
            # annealed_temp = tf.cond(
            # 			tf.less(tf.random_uniform([]), 0.95), lambda: annealed_temp_decay,
            # 			lambda: tf.random_uniform([], minval=0.5, maxval=1.0)) # 10% step for
            tf.logging.info(
                "****** apply sel-gan gumbel-softmax temperature annealing method ******* "
            )
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif kargs.get('gumbel_anneal', 'vqvae_v2') == 'vqvae_v2':
            temperature_warmup_steps = kargs.get("num_train_steps",
                                                 10000) * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(kargs.get("num_train_steps", 10000) * 0.1))
            steps = temperature_warmup_steps
            gumbel_samples = sample_gumbel(bert_utils.get_shape_list(
                flat_logits_tempered, expected_rank=2),
                                           samples=config.get('gen_sample', 1))
            # gumbel_samples *= inverse_exp_decay(steps)
            annealed_temp_decay = inverse_temp_exp_decay(
                kargs.get("num_train_steps", 10000),
                kargs.get("max_temp", 100))
            annealed_temp = annealed_temp_decay
            # annealed_temp = tf.cond(
            # 			tf.less(tf.random_uniform([]), 0.95), lambda: annealed_temp_decay,
            # 			lambda: tf.random_uniform([], minval=0.5, maxval=1.0)) # 10% step for
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax temperature annealing method ******* "
            )
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax num_train_steps:%s annealing method, temp:%s ******* ",
                str(kargs.get("num_train_steps", 10000)),
                str(kargs.get("max_temp", 100)))
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        else:
            annealed_temp = 0.01
            gumbel_samples = None
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('gumbel_temperature', annealed_temp)

        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        if kargs.get('stable_gradient', True):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=config.get('gen_sample', 1),
                greedy=kargs.get("greedy", False))
            tf.logging.info(
                "****** apply normal derivate for gradient calculation *******"
            )
        else:
            sampled_logprob_temp, sampled_logprob = gumbel_softmax_custom_grad(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=config.get('gen_sample', 1))
            tf.logging.info(
                "****** apply log deriviate for stable gradient calculation *******"
            )

        # argmax on config.vocab_size which is always axis=1
        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        # argmax(logits + gumbel_samples) samples from a categorical distribution
        if kargs.get('sampled_prob_id', False):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp,
                                                   axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id

        # straight-through gumbel softmax estimator
        if kargs.get('if_flip_grad', True):
            tf.logging.info("****** apply gradient flipping *******")
            sampled_logprob_temp_1 = flip_gradient(sampled_logprob_temp)
        else:
            tf.logging.info("****** not apply gradient flipping *******")
            sampled_logprob_temp_1 = sampled_logprob_temp
        if kargs.get("straight_through", True):
            tf.logging.info("****** apply straight_through_estimator *******")
            sampled_id = tf.stop_gradient(sampled_hard_id -
                                          sampled_logprob_temp) + (
                                              sampled_logprob_temp_1)
        else:
            tf.logging.info("****** apply gumbel-softmax probs *******")
            sampled_id = sampled_logprob_temp_1

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(
                sampled_binary_mask)  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids,
                        tf.int32)  # 0 for original and 1 for replace
            )

        label_diff_ids = tf.cast(label_diff_ids, tf.float32)

        label_diff_ids = tf.expand_dims(label_diff_ids,
                                        axis=[-1])  # batch x seq x 1
        input_ori_ids_1 = input_ori_ids
        input_ori_ids = tf.one_hot(input_ori_ids,
                                   config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                unk_mask = tf.cast(tf.math.equal(input_ori_ids_1, 100),
                                   tf.float32)  # not replace unk
                cls_mask = tf.cast(tf.math.equal(input_ori_ids_1, 101),
                                   tf.float32)  # not replace cls
                sep_mask = tf.cast(tf.math.equal(input_ori_ids_1, 102),
                                   tf.float32)  # not replace sep
                unsampled_mask = (1 -
                                  (unk_mask + cls_mask + sep_mask)) * tf.cast(
                                      input_mask, tf.float32)
                unsampled_mask = tf.expand_dims(unsampled_mask,
                                                axis=[-1])  # batch x seq x 1
                ori_input_mask = tf.expand_dims(input_mask,
                                                axis=[-1])  # batch x seq x 1

                tf.logging.info("****** all mask sample *******")
                sampled_input_id = unsampled_mask * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - unsampled_mask) * tf.cast(
                        ori_input_mask, tf.float32) * tf.cast(
                            input_ori_ids, tf.float32)
        else:
            sampled_input_id = tf.reshape(sampled_id, [
                batch_size, seq_length, config.vocab_size,
                config.get('gen_sample', 1)
            ])
            label_diff_ids = tf.expand_dims(label_diff_ids,
                                            axis=-1)  # batch x seq x 1
            input_ori_ids = tf.expand_dims(input_ori_ids,
                                           axis=-1)  # batch x seq x vocab x 1

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - input_ori_ids) * label_diff_ids

        tf.logging.info("====generator use_tpu %s ====",
                        str(kargs.get('use_tpu', True)))
        if not kargs.get('use_tpu', True):
            tf.logging.info("====logging generator loss ====")
            sampled_not_equal_id = tf.not_equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))
            sampled_equal_id = tf.equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))

            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(
                                            input_mask, tf.float32)

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
                label_diff_ids_my = tf.cast(label_diff_ids, tf.float32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                sampled_equal = tf.cast(
                    sampled_equal_id, tf.float32) * tf.cast(
                        tf.squeeze(unsampled_mask, axis=-1), tf.float32)
                tf.summary.scalar('generator_equal_sample_acc',
                                  tf.reduce_sum(sampled_equal))
                label_diff_ids_my = tf.cast(unsampled_mask, tf.float32)
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                sampled_equal = tf.reduce_sum(sampled_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)
            tf.summary.scalar("generator_valid_token",
                              tf.reduce_sum(input_mask))

            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp,
                                                   axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
            sampled_hard_id = tf.cast(sampled_hard_id, tf.float32)
            sampled_hard_id = tf.reshape(
                sampled_hard_id, [batch_size, seq_length, config.vocab_size])

            sampled_soft_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            sampled_hard_id *= label_diff_ids_my
            sampled_soft_id *= label_diff_ids_my

            hard_soft_bias = tf.reduce_sum(
                tf.sqrt(
                    tf.reduce_sum(tf.pow(sampled_hard_id - sampled_soft_id, 2),
                                  axis=-1))) / (1e-10 + tf.reduce_sum(
                                      tf.cast(label_diff_ids_my, tf.float32)))
            tf.summary.scalar('soft_hard_bias', hard_soft_bias)

        return sampled_input_id
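The "vqvae"-style branches above anneal the Gumbel temperature from roughly 1.0 down toward a small floor over training via the `inverse_lin_decay` / `inverse_exp_decay` helpers referenced in the code (t2t-style schedules). The following standalone sketch approximates that schedule with a plain linear ramp, so the exact constants here are assumptions rather than the behavior of those helpers.

import tensorflow as tf

def annealed_temperature_sketch(num_train_steps, floor=0.2):
    step = tf.cast(tf.train.get_or_create_global_step(), tf.float32)
    progress = tf.minimum(step / float(num_train_steps), 1.0)  # 0 -> 1 over training
    return tf.maximum(1.01 - progress, floor)                  # ~1.01 -> floor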
Code Example #6
def efficient_attention_layer(from_tensor,
                              to_tensor,
                              attention_mask=None,
                              num_attention_heads=1,
                              size_per_head=512,
                              query_act=None,
                              key_act=None,
                              value_act=None,
                              attention_probs_dropout_prob=0.0,
                              initializer_range=0.02,
                              do_return_2d_tensor=False,
                              batch_size=None,
                              from_seq_length=None,
                              to_seq_length=None,
                              attention_fixed_size=None):
    """Performs multi-headed attention from `from_tensor` to `to_tensor`.

	This is an implementation of multi-headed attention based on "Attention
	is all you Need". If `from_tensor` and `to_tensor` are the same, then
	this is self-attention. Each timestep in `from_tensor` attends to the
	corresponding sequence in `to_tensor`, and returns a fixed-width vector.

	This function first projects `from_tensor` into a "query" tensor and
	`to_tensor` into "key" and "value" tensors. These are (effectively) a list
	of tensors of length `num_attention_heads`, where each tensor is of shape
	[batch_size, seq_length, size_per_head].

	Then, the query and key tensors are dot-producted and scaled. These are
	softmaxed to obtain attention probabilities. The value tensors are then
	interpolated by these probabilities, then concatenated back to a single
	tensor and returned.

	In practice, the multi-headed attention is done with transposes and
	reshapes rather than actual separate tensors.

	Args:
		from_tensor: float Tensor of shape [batch_size, from_seq_length,
			from_width].
		to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
		attention_mask: (optional) int32 Tensor of shape [batch_size,
			from_seq_length, to_seq_length]. The values should be 1 or 0. The
			attention scores will effectively be set to -infinity for any positions in
			the mask that are 0, and will be unchanged for positions that are 1.
		num_attention_heads: int. Number of attention heads.
		size_per_head: int. Size of each attention head.
		query_act: (optional) Activation function for the query transform.
		key_act: (optional) Activation function for the key transform.
		value_act: (optional) Activation function for the value transform.
		attention_probs_dropout_prob: (optional) float. Dropout probability of the
			attention probabilities.
		initializer_range: float. Range of the weight initializer.
		do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
			* from_seq_length, num_attention_heads * size_per_head]. If False, the
			output will be of shape [batch_size, from_seq_length, num_attention_heads
			* size_per_head].
		batch_size: (Optional) int. If the input is 2D, this might be the batch size
			of the 3D version of the `from_tensor` and `to_tensor`.
		from_seq_length: (Optional) If the input is 2D, this might be the seq length
			of the 3D version of the `from_tensor`.
		to_seq_length: (Optional) If the input is 2D, this might be the seq length
			of the 3D version of the `to_tensor`.

	Returns:
		float Tensor of shape [batch_size, from_seq_length,
			num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
			true, this will be of shape [batch_size * from_seq_length,
			num_attention_heads * size_per_head]).

	Raises:
		ValueError: Any of the arguments or tensor shapes are invalid.
	"""
    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])

        output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
        return output_tensor

    from_shape = bert_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = bert_utils.get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None
                or to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`

    if attention_fixed_size:
        attention_head_size = attention_fixed_size
        tf.logging.info("==apply attention_fixed_size==",
                        str(attention_head_size))
    else:
        attention_head_size = size_per_head
        tf.logging.info("==apply attention_original_size==",
                        str(attention_head_size))

    from_tensor_2d = bert_utils.reshape_to_matrix(from_tensor)
    to_tensor_2d = bert_utils.reshape_to_matrix(to_tensor)

    # `query_layer` = [B*F, N*H]
    query_layer = tf.layers.dense(
        from_tensor_2d,
        num_attention_heads * attention_head_size,
        activation=query_act,
        name="query",
        kernel_initializer=albert_modules.create_initializer(
            initializer_range))

    # `key_layer` = [B*T, N*H]
    key_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * attention_head_size,
        activation=key_act,
        name="key",
        kernel_initializer=albert_modules.create_initializer(
            initializer_range))

    # `value_layer` = [B*T, N*H]
    value_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * attention_head_size,
        activation=value_act,
        name="value",
        kernel_initializer=albert_modules.create_initializer(
            initializer_range))

    # standard attention: softmax(Q K^T / sqrt(d)) V
    # efficient attention: softmax(Q) (softmax(K)^T V)

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                       num_attention_heads, from_seq_length,
                                       attention_head_size)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size,
                                     num_attention_heads, to_seq_length,
                                     attention_head_size)

    # `value_layer` = [B, N, T, H]
    value_layer = transpose_for_scores(value_layer, batch_size,
                                       num_attention_heads, to_seq_length,
                                       attention_head_size)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    # `attention_scores` = [B, N, H, H]<---[B, N, T, H] x [B, N, T, H]
    # key_mask = [B, T, 1, 1]

    attention_mask = tf.cast(
        tf.expand_dims(attention_mask[:, 0:1, :], axis=[2]), tf.float32)
    attention_mask = tf.cast(tf.expand_dims(attention_mask, axis=[3]),
                             tf.float32)
    # key_mask = [B, 1, T, 1]
    attention_mask = tf.reshape(attention_mask,
                                [batch_size, 1, to_seq_length, 1])
    adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
    attention_scores = tf.nn.log_softmax(key_layer + adder, axis=2)
    attention_probs = tf.exp(attention_scores)
    attention_probs = albert_modules.dropout(attention_probs,
                                             attention_probs_dropout_prob)

    key_value_scores = tf.matmul(attention_probs,
                                 value_layer,
                                 transpose_a=True)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    # [B, N, F, H] x [B, N, H, H]--->[B, N, F, H]
    context_layer = tf.matmul(tf.exp(tf.nn.log_softmax(query_layer, axis=-1)),
                              key_value_scores)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*V]
        context_layer = tf.reshape(context_layer, [
            batch_size * from_seq_length,
            num_attention_heads * attention_head_size
        ])
    else:
        # `context_layer` = [B, F, N*V]
        context_layer = tf.reshape(context_layer, [
            batch_size, from_seq_length,
            num_attention_heads * attention_head_size
        ])

    return context_layer, attention_scores, value_layer
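Unlike standard attention, which materializes a [T, T] score matrix, the function above factorizes the computation as softmax over the query features times (softmax over key positions)^T times V, so the cost grows linearly rather than quadratically with sequence length. A conceptual sketch of that factorization, with shapes [batch, heads, seq, head_dim] and the mask handled as an additive adder just as above; this is an illustration, not the repository's exact implementation.

import tensorflow as tf

def efficient_attention_sketch(q, k, v, key_mask_adder):
    # key_mask_adder: 0 at valid key positions, a large negative value at padded ones
    k_probs = tf.nn.softmax(k + key_mask_adder, axis=2)  # normalize over key positions
    context = tf.matmul(k_probs, v, transpose_a=True)    # [batch, heads, head_dim, head_dim]
    q_probs = tf.nn.softmax(q, axis=-1)                  # normalize over the feature axis
    return tf.matmul(q_probs, context)                   # [batch, heads, seq, head_dim]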
Code Example #7
def get_masked_lm_output(config, input_tensor, output_weights, positions,
							label_ids, label_weights,
							**kargs):

	reuse = kargs.get('reuse', False)
	embedding_projection = kargs.get('embedding_projection', None)
	"""Get loss and log probs for the masked LM."""
	input_tensor = tf.cast(input_tensor, tf.float32)
	positions = tf.cast(positions, tf.int32)
	label_ids = tf.cast(label_ids, tf.int32)
	label_weights = tf.cast(label_weights, tf.float32)

	input_tensor = bert_utils.gather_indexes(input_tensor, positions)
	"""
	flatten masked lm ids with positions
	"""

	scope = kargs.get('scope', None)
	if scope:
		scope = scope + '/' + 'cls/predictions'
	else:
		scope = 'cls/predictions'

	tf.logging.info("**** mlm scope **** %s", str(scope))

	# with tf.variable_scope("cls/predictions", reuse=reuse):
	with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
		# We apply one more non-linear transformation before the output layer.
		# This matrix is not used after pre-training.
		if config.get('ln_type', 'postln') == 'preln':
			input_tensor = albert_modules.layer_norm(input_tensor)
		elif config.get('ln_type', 'postln') == 'postln':
			input_tensor = input_tensor
		else:
			input_tensor = input_tensor

		if config.get("embedding", "factorized") == "factorized":
			projection_width = config.hidden_size
		else:
			projection_width = config.embedding_size

		with tf.variable_scope("transform"):
			input_tensor = tf.layers.dense(
					input_tensor,
					units=projection_width,
					activation=albert_modules.get_activation(config.hidden_act),
					kernel_initializer=albert_modules.create_initializer(
							config.initializer_range))
			if config.get('ln_type', 'postln') == 'preln':
				input_tensor = input_tensor
			elif config.get('ln_type', 'postln') == 'postln':
				input_tensor = albert_modules.layer_norm(input_tensor)
			else:
				input_tensor = albert_modules.layer_norm(input_tensor)

		if embedding_projection is not None:
			input_tensor = tf.matmul(input_tensor, 
								embedding_projection,
								transpose_b=True)
		else:
			print("==no need for embedding projection==")
			input_tensor = input_tensor

		# The output weights are the same as the input embeddings, but there is
		# an output-only bias for each token.
		output_bias = tf.get_variable(
				"output_bias",
				shape=[config.vocab_size],
				initializer=tf.zeros_initializer())
		logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
		# logits = tf.multiply(logits,
		# 						1.0 / math.sqrt(float(config.hidden_size)))
		# logits *= 2
		
		logits = tf.nn.bias_add(logits, output_bias)
		log_probs = tf.nn.log_softmax(logits, axis=-1)

		label_ids = tf.reshape(label_ids, [-1])
		label_weights = tf.reshape(label_weights, [-1])

		# one_hot_labels = tf.one_hot(
		# 		label_ids, depth=config.vocab_size, dtype=tf.float32)

		per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
													labels=tf.stop_gradient(label_ids),
													logits=logits)
		# per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])

		numerator = tf.reduce_sum(label_weights * per_example_loss)
		denominator = tf.reduce_sum(label_weights) + 1e-5

		# The `positions` tensor might be zero-padded (if the sequence is too
		# short to have the maximum number of predictions). The `label_weights`
		# tensor has a value of 1.0 for every real prediction and 0.0 for the
		# padding predictions.
		# per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
		# numerator = tf.reduce_sum(label_weights * per_example_loss)
		# denominator = tf.reduce_sum(label_weights) + 1e-5
		loss = numerator / denominator

	return (loss, per_example_loss, log_probs, label_weights)
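The masked-LM head relies on bert_utils.gather_indexes to pull out only the hidden states at the masked positions before projecting to the vocabulary. A standalone approximation of that helper is sketched below; the real one lives in bert_utils, and passing the static shapes explicitly is a simplification.

import tensorflow as tf

def gather_indexes_sketch(sequence_tensor, positions, batch_size, seq_length, hidden):
    # sequence_tensor: [batch_size, seq_length, hidden]
    # positions: int32 [batch_size, max_predictions] of masked token indices
    flat_offsets = tf.reshape(tf.range(batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence = tf.reshape(sequence_tensor, [batch_size * seq_length, hidden])
    return tf.gather(flat_sequence, flat_positions)  # [batch_size * max_predictions, hidden]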
Code Example #8
def transformer_cell(input_tensor,
                     attention_mask=None,
                     hidden_size=768,
                     num_hidden_layers=12,
                     num_attention_heads=12,
                     intermediate_size=3072,
                     intermediate_act_fn=gelu,
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     initializer_range=0.02,
                     do_return_all_layers=False,
                     shared_type=None,
                     adapter_fn=None):

    # These locals are referenced below but were not defined in the original snippet;
    # their reconstruction from the surrounding BERT/ALBERT code is an assumption.
    input_shape = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    attention_head_size = int(hidden_size / num_attention_heads)
    all_attention_scores = []
    layer_idx = 0

    layer_input = bert_utils.reshape_to_matrix(input_tensor)
    with tf.variable_scope("layer_shared", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("attention", reuse=tf.AUTO_REUSE):
            attention_heads = []
            with tf.variable_scope("self"):
                [attention_head,
                 attention_scores] = albert_modules.attention_layer(
                     from_tensor=layer_input,
                     to_tensor=layer_input,
                     attention_mask=attention_mask,
                     num_attention_heads=num_attention_heads,
                     size_per_head=attention_head_size,
                     attention_probs_dropout_prob=attention_probs_dropout_prob,
                     initializer_range=initializer_range,
                     do_return_2d_tensor=True,
                     batch_size=batch_size,
                     from_seq_length=seq_length,
                     to_seq_length=seq_length)
                attention_heads.append(attention_head)
                all_attention_scores.append(attention_scores)

            attention_output = None
            if len(attention_heads) == 1:
                attention_output = attention_heads[0]
            else:
                # In the case where we have other sequences, we just concatenate
                # them to the self-attention head before the projection.
                attention_output = tf.concat(attention_heads, axis=-1)

            # Run a linear projection of `hidden_size` then add a residual
            # with `layer_input`.
            with tf.variable_scope("output", reuse=tf.AUTO_REUSE):
                attention_output = tf.layers.dense(
                    attention_output,
                    hidden_size,
                    kernel_initializer=albert_modules.create_initializer(
                        initializer_range))
                attention_output = albert_modules.dropout(
                    attention_output, hidden_dropout_prob)

                if adapter_fn:
                    attention_output = adapter_fn(attention_output,
                                                  layer_idx=layer_idx)

                attention_output = albert_modules.layer_norm(attention_output +
                                                             layer_input)

        # The activation is only applied to the "intermediate" hidden layer.
        with tf.variable_scope('intermediate', reuse=tf.AUTO_REUSE):
            intermediate_output = tf.layers.dense(
                attention_output,
                intermediate_size,
                activation=intermediate_act_fn,
                kernel_initializer=albert_modules.create_initializer(
                    initializer_range))

        # Down-project back to `hidden_size` then add the residual.
        with tf.variable_scope('output', reuse=tf.AUTO_REUSE):
            layer_output = tf.layers.dense(
                intermediate_output,
                hidden_size,
                kernel_initializer=albert_modules.create_initializer(
                    initializer_range))
            layer_output = albert_modules.dropout(layer_output,
                                                  hidden_dropout_prob)

        if adapter_fn:
            layer_output = adapter_fn(attention_output, layer_idx=layer_idx)

        layer_output = albert_modules.layer_norm(layer_output +
                                                 attention_output)

    # Returning the block output and attention scores here is an assumption; the
    # original snippet ends without an explicit return.
    return layer_output, all_attention_scores
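Stripped of variable scopes, dropout details, and the optional adapter, the cell above is the usual post-LN transformer block: self-attention, output projection, residual plus layer norm, then a feed-forward expansion and another residual plus layer norm. A condensed sketch of that pattern; `attention_fn` and the ReLU activation stand in for the albert_modules helpers, and TF 1.x with tf.contrib is assumed.

import tensorflow as tf

def post_ln_block_sketch(x, attention_fn, hidden_size=768, intermediate_size=3072):
    # x: [batch_size * seq_length, hidden_size]
    attn = attention_fn(x)                             # self-attention output
    attn = tf.layers.dense(attn, hidden_size)          # output projection
    attn = tf.contrib.layers.layer_norm(attn + x)      # residual + post-LN
    ffn = tf.layers.dense(attn, intermediate_size, activation=tf.nn.relu)
    out = tf.layers.dense(ffn, hidden_size)            # project back down
    return tf.contrib.layers.layer_norm(out + attn)    # residual + post-LN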
Code Example #9
File: token_generator.py  Project: Beleiaya/BERT
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        # 	projection_width = config.hidden_size
        # else:
        # 	projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        if not kargs.get("apply_valid_vocab", False):
            logits = logits
            tf.logging.info("****** normal logits *******")
        elif kargs.get("apply_valid_vocab", False) == 'topk':
            prob, _ = top_k_softmax(logits, kargs.get('topk', 10))
            logits = tf.log(prob + 1e-10)
            tf.logging.info("****** topk logits *******")
        else:
            invalid_size = kargs.get("invalid_size", 106)
            invalid_mask = tf.cast(
                tf.ones((1, invalid_size)) * (-10000), tf.float32)
            valid_mask = tf.cast(
                tf.zeros((1, config.vocab_size - invalid_size)), tf.float32)
            invalid_mask = tf.concat([invalid_mask, valid_mask], axis=-1)
            invalid_mask = tf.expand_dims(invalid_mask,
                                          axis=1)  # 1 x 1 x vocab, broadcasts over batch and seq
            logits += tf.cast(invalid_mask, tf.float32)
            tf.logging.info(
                "****** only valid logits ******* , invalid size: %s",
                str(invalid_size))

        logits_tempered = tf.nn.log_softmax(logits /
                                            config.get("temperature", 1.0))

        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        # flat_logits_tempered_topk = top_k_logits(flat_logits_tempered, int(config.vocab_size/2
        if not kargs.get("greedy", False):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=1.0,
                samples=config.get('gen_sample', 1),
                greedy=kargs.get("greedy", False))

            samples = tf.argmax(sampled_logprob, axis=1)  # batch x seq
            tf.logging.info("****** normal sample *******")
        else:
            samples = tf.argmax(flat_logits_tempered, axis=-1)
            tf.logging.info("****** greedy sample *******")

        # samples = tf.multinomial(flat_logits_tempered,
        # 						num_samples=config.get('gen_sample', 1),
        # 						output_dtype=tf.int32)

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = sampled_binary_mask  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids,
                        tf.int32)  # 0 for original and 1 for replace
            )
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        if not kargs.get('use_tpu', True):
            tf.summary.scalar(
                'label_diff_ids',
                tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32))
                / tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                input_ori_ids_1 = input_ori_ids
                unk_mask = tf.cast(tf.math.equal(input_ori_ids_1, 100),
                                   tf.float32)  # not replace unk
                cls_mask = tf.cast(tf.math.equal(input_ori_ids_1, 101),
                                   tf.float32)  # not replace cls
                sep_mask = tf.cast(tf.math.equal(input_ori_ids_1, 102),
                                   tf.float32)  # not replace sep
                unsampled_mask = (1 -
                                  (unk_mask + cls_mask + sep_mask)) * tf.cast(
                                      input_mask, tf.float32)
                # unsampled_mask = tf.expand_dims(unsampled_mask, axis=[-1]) # batch x seq x 1
                tf.logging.info("****** all mask sample *******")
                sampled_input_id = unsampled_mask * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - unsampled_mask) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length,
                          config.get('gen_sample', 1)])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                # batch x seq_length x 1
                label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
                label_diff_ids = tf.einsum(
                    'abc,cd->abd', label_diff_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                # batch x seq_length x 1
                input_ori_ids = tf.cast(input_ori_ids, tf.float32)
                input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)
                input_ori_ids = tf.einsum(
                    'abc,cd->abd', input_ori_ids,
                    tf.ones((1, config.get('gen_sample', 1))))

                # keep the sampled ids at replaced positions (label_diff_ids == 1)
                # and restore the original ids elsewhere, mirroring the
                # single-sample branch above
                sampled_input_id = label_diff_ids * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - label_diff_ids) * input_ori_ids
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

                input_mask = tf.cast(input_mask, tf.float32)
                input_mask = tf.expand_dims(input_mask, axis=-1)
                input_mask = tf.einsum(
                    'abc,cd->abd', input_mask,
                    tf.ones((1, config.get('gen_sample', 1))))

        if not kargs.get('use_tpu', True):

            sampled_not_equal_id = tf.not_equal(
                tf.cast(sampled_input_id, tf.int32),
                tf.cast(input_ori_ids, tf.int32))
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(
                                            input_mask, tf.float32)

            sampled_equal_id = tf.equal(tf.cast(sampled_input_id, tf.int32),
                                        tf.cast(input_ori_ids, tf.int32))

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                sampled_equal = tf.cast(sampled_equal_id,
                                        tf.float32) * tf.cast(
                                            unsampled_mask, tf.float32)
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                # normalize before logging so the summary reports a fraction,
                # not a raw count of matching positions
                sampled_equal = tf.reduce_sum(sampled_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                tf.summary.scalar('generator_equal_sample_acc', sampled_equal)
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        return sampled_input_id
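
The generator above depends on a gumbel_softmax helper that is not shown in this listing. Below is a minimal single-sample sketch of what such a helper might look like; the name gumbel_softmax_sketch, its exact return values, and the handling of the samples argument are assumptions rather than this project's implementation. It perturbs the flattened logits with Gumbel(0, 1) noise so that the tf.argmax taken above recovers a categorical sample (the Gumbel-max trick).

import tensorflow as tf

def gumbel_softmax_sketch(flat_logits, temperature=1.0, samples=1, greedy=False):
    # flat_logits: [batch * seq, vocab]; `samples` > 1 is not handled in this sketch.
    if greedy:
        noise = tf.zeros_like(flat_logits)
    else:
        uniform = tf.random_uniform(tf.shape(flat_logits), minval=1e-10, maxval=1.0)
        noise = -tf.log(-tf.log(uniform))  # Gumbel(0, 1) noise
    perturbed = flat_logits / temperature + noise
    log_probs = tf.nn.log_softmax(perturbed, axis=-1)
    # tf.argmax over either returned tensor yields a sample drawn from
    # softmax(flat_logits / temperature) via the Gumbel-max trick.
    return log_probs, tf.exp(log_probs)
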
Code example #10
0
File: pretrain_albert.py Project: Beleiaya/BERT
def seq_mask_masked_lm_output(config, input_tensor, output_weights,
							input_mask, input_ori_ids, input_ids, 
							sampled_binary_mask, **kargs):

	input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
	batch_size = input_shape_list[0]
	seq_length = input_shape_list[1]
	hidden_dims = input_shape_list[2]

	embedding_projection = kargs.get('embedding_projection', None)

	scope = kargs.get('scope', None)
	if scope:
		scope = scope + '/' + 'cls/predictions'
	else:
		scope = 'cls/predictions'

	tf.logging.info("**** mlm generator scope **** %s", str(scope))

	# with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
	with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
		if config.get('ln_type', 'postln') == 'preln':
			input_tensor = albert_modules.layer_norm(input_tensor)
		elif config.get('ln_type', 'postln') == 'postln':
			input_tensor = input_tensor
		else:
			input_tensor = input_tensor

		if config.get("embedding", "factorized") == "factorized":
			projection_width = config.hidden_size
		else:
			projection_width = config.embedding_size

		with tf.variable_scope("transform"):
			input_tensor = tf.layers.dense(
					input_tensor,
					units=projection_width,
					activation=albert_modules.get_activation(config.hidden_act),
					kernel_initializer=albert_modules.create_initializer(
							config.initializer_range))

			if config.get('ln_type', 'postln') == 'preln':
				input_tensor = input_tensor
			elif config.get('ln_type', 'postln') == 'postln':
				input_tensor = albert_modules.layer_norm(input_tensor)
			else:
				input_tensor = albert_modules.layer_norm(input_tensor)

		if embedding_projection is not None:
			# batch x seq x hidden, embedding x hidden
			print(input_tensor.get_shape(), embedding_projection.get_shape())
			input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection)
		else:
			print("==no need for embedding projection==")
			input_tensor = input_tensor

		output_bias = tf.get_variable(
				"output_bias",
				shape=[config.vocab_size],
				initializer=tf.zeros_initializer())
		# batch x seq x embedding
		logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
		logits = tf.nn.bias_add(logits, output_bias)

		"""
		if input_ori_ids[i] was randomly perturbed, sampled_binary_mask[i] = 1
		"""
		sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32)
		per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
												logits=logits,
												labels=tf.stop_gradient(input_ori_ids),
												)
		per_example_loss *= sampled_binary_mask
		loss = tf.reduce_sum(per_example_loss) / (1e-10 + tf.reduce_sum(sampled_binary_mask))

		return (loss, per_example_loss, logits, sampled_binary_mask)
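
As a quick illustration of the normalization above (hypothetical values, not part of the original file): only positions where sampled_binary_mask is 1 contribute to the loss, and the sum is divided by the number of such positions rather than by the full sequence length.

import numpy as np

per_example_loss = np.array([2.0, 0.5, 3.0, 1.0])     # hypothetical per-token losses
sampled_binary_mask = np.array([1.0, 0.0, 1.0, 0.0])  # 1 = perturbed position
loss = np.sum(per_example_loss * sampled_binary_mask) / (
    1e-10 + np.sum(sampled_binary_mask))
print(loss)  # (2.0 + 3.0) / 2 = 2.5
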
Code example #11
0
File: token_generator.py Project: CBHell/BERT
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        logits_tempered = logits / config.get("temperature", 1.0)

        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        flat_logits_tempered_topk = top_k_logits(flat_logits_tempered,
                                                 int(config.vocab_size / 2))

        samples = tf.multinomial(flat_logits_tempered_topk,
                                 num_samples=config.get('gen_sample', 1),
                                 output_dtype=tf.int32)

        label_diff_ids = tf.equal(tf.cast(input_ids, tf.int32),
                                  tf.cast(input_ori_ids, tf.int32))
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        tf.summary.scalar(
            'label_diff_ids',
            tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32)) /
            tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if kargs.get('mask_method', 'all') == 'only_mask':
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                # use the sampled ids only where input_ids differ from
                # input_ori_ids (label_diff_ids == 0) and keep the original
                # ids elsewhere
                sampled_input_id = (1 - label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32) + label_diff_ids * tf.cast(
                        input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length,
                          config.get('gen_sample', 1)])
            if kargs.get('mask_method', 'all') == 'only_mask':
                # batch x seq_length x 1
                label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
                label_diff_ids = tf.einsum(
                    'abc,cd->abd', label_diff_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                # batch x seq_length x 1
                input_ori_ids = tf.cast(input_ori_ids, tf.float32)
                input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)
                input_ori_ids = tf.einsum(
                    'abc,cd->abd', input_ori_ids,
                    tf.ones((1, config.get('gen_sample', 1))))

                sampled_input_id = (1 - label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + input_ori_ids * label_diff_ids
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

        return sampled_input_id
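
token_generator calls a top_k_logits helper that is not included in this listing. One common way to implement this kind of truncation (a sketch under that assumption, not necessarily this project's version) is to keep the k largest logits per row and push everything else to a large negative value, so that tf.multinomial effectively never samples the truncated entries.

def top_k_logits_sketch(logits, k):
    # logits: [rows, vocab]; returns logits with all but the top-k entries per
    # row replaced by a large negative value (k == 0 means no truncation).
    if k == 0:
        return logits
    values, _ = tf.nn.top_k(logits, k=k)       # [rows, k], sorted descending
    min_values = values[:, -1, tf.newaxis]     # k-th largest logit per row
    return tf.where(logits < min_values,
                    tf.ones_like(logits) * -1e10,
                    logits)
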