Beispiel #1
0
    def build_encoder(self, input_ids, input_mask, hidden_dropout_prob,
                      attention_probs_dropout_prob, **kargs):
        """Build the stacked ALBERT transformer encoder on top of the embeddings.

        Consumes `self.embedding_output` (presumably built by an earlier
        embedding step -- confirm call order in the caller) and stores the
        per-layer outputs and attention scores on `self`.

        Args:
          input_ids: token ids; only used to derive the attention mask.
          input_mask: 2D [batch_size, seq_length] mask that is expanded into a
            3D [batch_size, seq_length, seq_length] attention mask.
          hidden_dropout_prob: dropout probability for hidden layers.
          attention_probs_dropout_prob: dropout probability for attention
            probabilities.
          **kargs: `reuse` is forwarded to the outer variable scope
            (None, i.e. inherit, when omitted).

        Raises:
          ValueError: if `config.ln_type` is neither 'postln' nor 'preln'.

        Side effects:
          Sets `self.all_encoder_layers` and `self.all_attention_scores`.
        """
        reuse = kargs.get("reuse", None)

        # Resolve the layer-norm placement before touching the graph, so an
        # invalid config fails fast with a clear message instead of an
        # unbound-variable error halfway through graph construction.
        ln_type = self.config.get("ln_type", "postln")
        if ln_type == 'postln':
            tf.logging.info('==apply post layer==')
            transformer_model = albert_modules.transformer_model
        elif ln_type == 'preln':
            tf.logging.info('==apply pre layer==')
            transformer_model = albert_modules.prelln_transformer_model
        else:
            raise ValueError(
                "Unsupported ln_type %r; expected 'postln' or 'preln'." %
                (ln_type,))

        with tf.variable_scope(self.config.get("scope", "bert"), reuse=reuse):
            with tf.variable_scope("encoder"):
                # Convert the 2D [batch_size, seq_length] mask into the 3D
                # [batch_size, seq_length, seq_length] mask used on the
                # attention scores.
                attention_mask = albert_modules.create_attention_mask_from_input_mask(
                    input_ids, input_mask)

                tf.logging.info("===number of hidden layers=== %s",
                                self.config.num_hidden_layers)

                # Run the stacked transformer; all layers are returned
                # (do_return_all_layers=True).
                self.all_encoder_layers, self.all_attention_scores = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=attention_mask,
                    hidden_size=self.config.hidden_size,
                    num_hidden_layers=self.config.num_hidden_layers,
                    num_attention_heads=self.config.num_attention_heads,
                    intermediate_size=self.config.intermediate_size,
                    intermediate_act_fn=albert_modules.get_activation(
                        self.config.hidden_act),
                    hidden_dropout_prob=hidden_dropout_prob,
                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                    initializer_range=self.config.initializer_range,
                    do_return_all_layers=True,
                    shared_type=self.config.get("shared_type", None),
                    adapter_fn=bert_adapter_modules.get_adapter(
                        self.config.get('adapter_fn', None)))
Beispiel #2
0
def token_generator_gumbel_normal(config, input_tensor, output_weights,
                                  input_ids, input_ori_ids, input_mask,
                                  **kargs):
    """Sample generator tokens with Gumbel-softmax over log-normalized logits.

    Builds the `cls/predictions` transform head on `input_tensor`, projects it
    onto `output_weights` to get vocabulary logits, applies `log_softmax`, and
    draws Gumbel-softmax samples. With `mask_method='only_mask'` (default),
    sampled tokens are kept only at "replaced" positions; every other position
    keeps the one-hot of the original token.

    Args:
      config: model config, read via attributes (`hidden_size`, `hidden_act`,
        `initializer_range`, `vocab_size`) and `.get` (`ln_type`, `embedding`,
        `embedding_size`, `gumbel_temperature`, `gen_sample`).
      input_tensor: rank-3 hidden states [batch_size, seq_length, hidden].
      output_weights: matrix contracted with the transformed states via
        "abc,dc->abd"; its first dim must equal `vocab_size` for the bias add.
      input_ids: (possibly corrupted) token ids, [batch_size, seq_length].
      input_ori_ids: original token ids, same shape as `input_ids`.
      input_mask: padding mask; only used for the accuracy summary.
      **kargs: optional `embedding_projection`, `scope`, `num_train_steps`,
        `gumbel_anneal` ('anneal' needs `num_train_steps`; 'softplus';
        anything else uses a fixed temperature of 1.0), `sampled_prob_id`,
        `straight_through`, `sampled_binary_mask`, `mask_method`, `use_tpu`.

    Returns:
      Float tensor [batch_size, seq_length, vocab_size] when `gen_sample` is
      1, otherwise [batch_size, seq_length, vocab_size, gen_sample].
    """
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    embedding_projection = kargs.get('embedding_projection', None)
    ln_type = config.get('ln_type', 'postln')
    gen_sample = config.get('gen_sample', 1)
    gumbel_anneal = kargs.get('gumbel_anneal', "anneal")
    mask_method = kargs.get('mask_method', 'only_mask')
    use_tpu = kargs.get('use_tpu', True)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # AUTO_REUSE: variables already created under this scope are reused.
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Pre-LN normalizes before the transform; post-LN (and any other
        # value) normalizes after the dense layer below.
        if ln_type == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if ln_type != 'preln':
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # [batch, seq, width] x [d, width] -> [batch, seq, d]
            tf.logging.info("%s %s", input_tensor.get_shape(),
                            embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            tf.logging.info("==no need for embedding projection==")

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # [batch, seq, vocab_size]
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)
        width = bert_utils.get_shape_list(logits, expected_rank=3)[2]

        # Log-normalized logits, flattened to [batch * seq, vocab_size].
        logits_tempered = tf.nn.log_softmax(logits, axis=-1)
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        # ---- Gumbel-softmax temperature -------------------------------
        num_train_steps = kargs.get('num_train_steps', None)
        if num_train_steps and gumbel_anneal == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            # Linear decay from `gumbel_temperature` to 0.1 over training.
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                num_train_steps,
                end_learning_rate=0.1,
                power=1.0,
                cycle=False)
        elif gumbel_anneal == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                # Per-token temperature 1 / (1 + softplus(dense(h))), in (0, 1).
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
            if gen_sample > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******"
                )
                # Extra axis so the temperature broadcasts over samples.
                annealed_temp = tf.expand_dims(annealed_temp, -1)
        else:
            annealed_temp = 1.0
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))

        # ---- Sampling ---------------------------------------------------
        # Per the original shape note: [batch x seq] x vocab_size (x gen_sample
        # when gen_sample > 1).
        sampled_logprob_temp, sampled_logprob = gumbel_softmax(
            flat_logits_tempered,
            temperature=annealed_temp,
            samples=gen_sample)

        # Hard one-hot sample; the vocab axis is always axis=1.
        if kargs.get('sampled_prob_id', True):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size,
                                         axis=1)
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp,
                                                   axis=1),
                                         config.vocab_size,
                                         axis=1)

        if kargs.get("straight_through", True):
            tf.logging.info(
                "****** apply straight_through_estimator without grl *******")
            # Straight-through: the forward value equals the hard sample,
            # gradients flow only through the soft sample.
            sampled_id = tf.stop_gradient(
                sampled_hard_id - sampled_logprob_temp) + sampled_logprob_temp
        else:
            # NOTE(review): the log says "without grl" but flip_gradient is
            # applied here -- confirm which is intended.
            tf.logging.info(
                "****** apply gumbel-softmax probs without grl *******")
            sampled_id = flip_gradient(sampled_logprob_temp)

        # ---- Blend samples with the original tokens ---------------------
        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            # 0 for original and 1 for replaced tokens.
            label_diff_ids = tf.identity(sampled_binary_mask)
        else:
            label_diff_ids = tf.not_equal(tf.cast(input_ids, tf.int32),
                                          tf.cast(input_ori_ids, tf.int32))
        # [batch, seq, 1] float mask: 1.0 where the token was replaced.
        label_diff_ids = tf.expand_dims(tf.cast(label_diff_ids, tf.float32),
                                        axis=-1)
        # [batch, seq, vocab_size] one-hot of the original tokens.
        input_ori_ids = tf.cast(tf.one_hot(input_ori_ids, config.vocab_size),
                                tf.float32)

        if gen_sample == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if mask_method == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = label_diff_ids * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - label_diff_ids) * input_ori_ids
        else:
            sampled_input_id = tf.reshape(
                sampled_id,
                [batch_size, seq_length, config.vocab_size, gen_sample])
            # [batch, seq, 1, 1]: broadcasts over vocab and samples.
            label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
            # [batch, seq, vocab, 1]: broadcasts over samples.
            input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)
            if mask_method == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = label_diff_ids * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - label_diff_ids) * input_ori_ids

        # ---- Diagnostics (non-TPU only) ---------------------------------
        tf.logging.info("====generator use_tpu %s ====", str(use_tpu))
        if not use_tpu:
            tf.logging.info("====logging generator loss ====")
            # argmax over the vocab axis (axis=2) recovers token ids.
            sampled_not_equal_id = tf.not_equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))
            valid_mask = tf.cast(input_mask, tf.float32)
            num_replaced = tf.reduce_sum(label_diff_ids)
            if gen_sample > 1:
                # Ids are [batch, seq, gen_sample]: broadcast the mask over
                # samples and count every sample in the denominator.
                valid_mask = tf.expand_dims(valid_mask, axis=-1)
                num_replaced = num_replaced * float(gen_sample)
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * valid_mask
            sampled_acc = 1 - tf.reduce_sum(sampled_not_equal) / (
                1e-10 + num_replaced)
            tf.summary.scalar('generator_sample_acc', sampled_acc)

        return sampled_input_id
Beispiel #3
0
def token_generator_gumbel(config, input_tensor, output_weights, input_ids,
                           input_ori_ids, input_mask, **kargs):
    """Sample generator tokens with Gumbel-softmax over raw MLM logits.

    Builds the `cls/predictions` transform head on `input_tensor`, projects it
    onto `output_weights` to get vocabulary logits (not log-normalized), picks
    a Gumbel-softmax temperature schedule, samples, and blends the samples
    with the original tokens according to `mask_method`.

    Args:
      config: model config, read via attributes (`hidden_size`, `hidden_act`,
        `initializer_range`, `vocab_size`) and `.get` (`ln_type`, `embedding`,
        `embedding_size`, `gumbel_temperature`, `gen_sample`).
      input_tensor: rank-3 hidden states [batch_size, seq_length, hidden].
      output_weights: matrix contracted with the transformed states via
        "abc,dc->abd"; its first dim must equal `vocab_size` for the bias add.
      input_ids: (possibly corrupted) token ids, [batch_size, seq_length].
      input_ori_ids: original token ids, same shape as `input_ids`.
      input_mask: padding mask; used by 'all_mask' and the diagnostics.
      **kargs: optional settings, including
        `gumbel_anneal`: 'anneal' (default; needs `num_train_steps`),
          'softplus', 'vqvae', 'vqvae_v1', 'vqvae_v2'; anything else (or
          'anneal' without `num_train_steps`) uses a fixed temperature 0.01.
        `num_train_steps`: step budget (vqvae* schedules default to 10000).
        `mask_method`: 'only_mask' (default) keeps samples only at replaced
          positions; 'all_mask' (single sample only) resamples every
          non-special, non-padding position.
        Also `embedding_projection`, `scope`, `max_temp`, `stable_gradient`,
        `greedy`, `sampled_prob_id`, `if_flip_grad`, `straight_through`,
        `sampled_binary_mask`, `use_tpu`.

    Returns:
      Float tensor [batch_size, seq_length, vocab_size] when `gen_sample` is
      1, otherwise [batch_size, seq_length, vocab_size, gen_sample].
    """
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    embedding_projection = kargs.get('embedding_projection', None)
    ln_type = config.get('ln_type', 'postln')
    gen_sample = config.get('gen_sample', 1)
    mask_method = kargs.get('mask_method', 'only_mask')
    use_tpu = kargs.get('use_tpu', True)
    # Read once with a single default, so an omitted `gumbel_anneal` never
    # falls through to the 'vqvae' schedule.
    gumbel_anneal = kargs.get('gumbel_anneal', 'anneal')
    num_train_steps = kargs.get('num_train_steps', None)
    # Step budget for the vqvae* schedules; tolerates num_train_steps=None.
    total_steps = num_train_steps if num_train_steps is not None else 10000

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # AUTO_REUSE: variables already created under this scope are reused.
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Pre-LN normalizes before the transform; post-LN (and any other
        # value) normalizes after the dense layer below.
        if ln_type == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if ln_type != 'preln':
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # [batch, seq, width] x [d, width] -> [batch, seq, d]
            tf.logging.info("%s %s", input_tensor.get_shape(),
                            embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            tf.logging.info("==no need for embedding projection==")

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # [batch, seq, vocab_size]
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)
        width = bert_utils.get_shape_list(logits, expected_rank=3)[2]

        # Raw (un-normalized) logits, flattened to [batch * seq, vocab_size].
        flat_logits_tempered = tf.reshape(logits,
                                          [batch_size * seq_length, width])

        # ---- Gumbel-softmax temperature schedule ------------------------
        if num_train_steps and gumbel_anneal == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            # Linear decay from `gumbel_temperature` to 0.01 over the first
            # 10% of the training steps.
            temperature_warmup_steps = int(num_train_steps) * 0.1
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                temperature_warmup_steps,
                end_learning_rate=0.01,
                power=1.0,
                cycle=False)
            gumbel_samples = None
            if not use_tpu:
                tf.summary.scalar('annealed_temp', annealed_temp)
        elif gumbel_anneal == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                # Per-token temperature 1 / (1 + softplus(dense(h))), in (0, 1).
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
            if not use_tpu:
                tf.summary.scalar('softplus temperature',
                                  tf.reduce_mean(annealed_temp))
            if gen_sample > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******"
                )
                # Extra axis so the temperature broadcasts over samples.
                annealed_temp = tf.expand_dims(annealed_temp, -1)
            gumbel_samples = None
        elif gumbel_anneal == 'vqvae':
            temperature_warmup_steps = total_steps * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(total_steps * 0.1))
            gumbel_samples = sample_gumbel(
                bert_utils.get_shape_list(flat_logits_tempered,
                                          expected_rank=2),
                samples=gen_sample)
            gumbel_samples *= inverse_exp_decay(temperature_warmup_steps //
                                                5) * 0.5
            # NOTE(review): an earlier comment claimed a minimum temperature
            # of 0.2; the floor actually depends on inverse_lin_decay's range.
            annealed_temp_decay = 1.01 - inverse_lin_decay(
                temperature_warmup_steps)
            # 90% of the time use the decayed temperature, otherwise draw one
            # uniformly from [0.5, 1.0).
            annealed_temp = tf.cond(
                tf.less(tf.random_uniform([]), 0.9),
                lambda: annealed_temp_decay,
                lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method ******* "
            )
            if not use_tpu:
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif gumbel_anneal == 'vqvae_v1':
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(total_steps * 0.1))
            gumbel_samples = sample_gumbel(
                bert_utils.get_shape_list(flat_logits_tempered,
                                          expected_rank=2),
                samples=gen_sample)
            annealed_temp_decay = 1.01 - inverse_exp_decay(total_steps)
            annealed_temp = annealed_temp_decay
            tf.logging.info(
                "****** apply sel-gan gumbel-softmax temperature annealing method ******* "
            )
            if not use_tpu:
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif gumbel_anneal == 'vqvae_v2':
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(total_steps * 0.1))
            gumbel_samples = sample_gumbel(
                bert_utils.get_shape_list(flat_logits_tempered,
                                          expected_rank=2),
                samples=gen_sample)
            annealed_temp_decay = inverse_temp_exp_decay(
                total_steps, kargs.get("max_temp", 100))
            annealed_temp = annealed_temp_decay
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax temperature annealing method ******* "
            )
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax num_train_steps:%s annealing method, temp:%s ******* ",
                str(total_steps), str(kargs.get("max_temp", 100)))
            if not use_tpu:
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        else:
            annealed_temp = 0.01
            gumbel_samples = None
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))
            if not use_tpu:
                tf.summary.scalar('gumbel_temperature', annealed_temp)

        # ---- Sampling ---------------------------------------------------
        # Per the original shape note: [batch x seq] x vocab_size (x gen_sample
        # when gen_sample > 1).
        if kargs.get('stable_gradient', True):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=gen_sample,
                greedy=kargs.get("greedy", False))
            tf.logging.info(
                "****** apply normal derivate for gradient calculation *******"
            )
        else:
            sampled_logprob_temp, sampled_logprob = gumbel_softmax_custom_grad(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=gen_sample)
            tf.logging.info(
                "****** apply log deriviate for stable gradient calculation *******"
            )

        # Hard one-hot sample; the vocab axis is always axis=1.
        if kargs.get('sampled_prob_id', False):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size,
                                         axis=1)
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp,
                                                   axis=1),
                                         config.vocab_size,
                                         axis=1)

        if kargs.get('if_flip_grad', True):
            tf.logging.info("****** apply gradient flipping *******")
            sampled_logprob_temp_1 = flip_gradient(sampled_logprob_temp)
        else:
            tf.logging.info("****** not apply gradient flipping *******")
            sampled_logprob_temp_1 = sampled_logprob_temp
        if kargs.get("straight_through", True):
            tf.logging.info("****** apply straight_through_estimator *******")
            # Straight-through: the forward value is the hard sample (plus the
            # forward value of sampled_logprob_temp_1 minus the soft sample);
            # gradients flow only through sampled_logprob_temp_1.
            sampled_id = tf.stop_gradient(
                sampled_hard_id - sampled_logprob_temp) + sampled_logprob_temp_1
        else:
            tf.logging.info("****** apply gumbel-softmax probs *******")
            sampled_id = sampled_logprob_temp_1

        # ---- Blend samples with the original tokens ---------------------
        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            # 0 for original and 1 for replaced tokens.
            label_diff_ids = tf.identity(sampled_binary_mask)
        else:
            label_diff_ids = tf.not_equal(tf.cast(input_ids, tf.int32),
                                          tf.cast(input_ori_ids, tf.int32))
        # [batch, seq, 1] float mask: 1.0 where the token was replaced.
        label_diff_ids = tf.expand_dims(tf.cast(label_diff_ids, tf.float32),
                                        axis=-1)
        ori_token_ids = input_ori_ids
        # [batch, seq, vocab_size] one-hot of the original tokens.
        input_ori_ids = tf.cast(tf.one_hot(input_ori_ids, config.vocab_size),
                                tf.float32)

        # Set only by 'all_mask': [batch, seq, 1], 1.0 where a sample is used.
        resample_mask = None
        if gen_sample == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if mask_method == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = label_diff_ids * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - label_diff_ids) * input_ori_ids
            elif mask_method == 'all_mask':
                # Ids 100/101/102 are never replaced (marked unk/cls/sep;
                # presumably the BERT vocab ids -- confirm), nor is padding.
                unk_mask = tf.cast(tf.math.equal(ori_token_ids, 100),
                                   tf.float32)
                cls_mask = tf.cast(tf.math.equal(ori_token_ids, 101),
                                   tf.float32)
                sep_mask = tf.cast(tf.math.equal(ori_token_ids, 102),
                                   tf.float32)
                resample_mask = (1 - (unk_mask + cls_mask + sep_mask)) * tf.cast(
                    input_mask, tf.float32)
                resample_mask = tf.expand_dims(resample_mask, axis=-1)
                ori_input_mask = tf.cast(tf.expand_dims(input_mask, axis=-1),
                                         tf.float32)

                tf.logging.info("****** all mask sample *******")
                # Samples where allowed; original one-hot on special tokens;
                # zeros on padding.
                sampled_input_id = resample_mask * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - resample_mask) * ori_input_mask * input_ori_ids
        else:
            sampled_input_id = tf.reshape(
                sampled_id,
                [batch_size, seq_length, config.vocab_size, gen_sample])
            # [batch, seq, 1, 1]: broadcasts over vocab and samples.
            label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
            # [batch, seq, vocab, 1]: broadcasts over samples.
            input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)
            if mask_method == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = label_diff_ids * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - label_diff_ids) * input_ori_ids

        # ---- Diagnostics (non-TPU only) ---------------------------------
        tf.logging.info("====generator use_tpu %s ====", str(use_tpu))
        if not use_tpu:
            tf.logging.info("====logging generator loss ====")
            # Token ids via argmax over the vocab axis (axis=2):
            # [batch, seq] for one sample, [batch, seq, gen_sample] otherwise.
            sampled_token_id = tf.cast(tf.argmax(sampled_input_id, axis=2),
                                       tf.int32)
            ori_token_id = tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32)
            sampled_not_equal_id = tf.not_equal(sampled_token_id, ori_token_id)

            valid_mask = tf.cast(input_mask, tf.float32)
            if gen_sample > 1:
                # Broadcast the [batch, seq] mask over the sample axis.
                valid_mask = tf.expand_dims(valid_mask, axis=-1)
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * valid_mask

            if resample_mask is not None:
                # 'all_mask': score every resampled position.
                scored_mask = resample_mask
                sampled_equal = tf.cast(
                    tf.equal(sampled_token_id, ori_token_id),
                    tf.float32) * tf.squeeze(resample_mask, axis=-1)
                tf.summary.scalar('generator_equal_sample_acc',
                                  tf.reduce_sum(sampled_equal))
            else:
                # 'only_mask' (and any other method): score replaced tokens.
                scored_mask = label_diff_ids
            num_scored = tf.reduce_sum(scored_mask)
            if gen_sample > 1:
                num_scored = num_scored * float(gen_sample)
            sampled_acc = 1 - tf.reduce_sum(sampled_not_equal) / (1e-10 +
                                                                 num_scored)
            tf.summary.scalar('generator_sample_acc', sampled_acc)
            tf.summary.scalar("generator_valid_token",
                              tf.reduce_sum(input_mask))

            if gen_sample == 1:
                # Mean L2 distance between the hard one-hot sample and the
                # soft sample over the scored positions.
                diag_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp,
                                                    axis=1),
                                          config.vocab_size,
                                          axis=1)
                diag_hard_id = tf.reshape(
                    tf.cast(diag_hard_id, tf.float32),
                    [batch_size, seq_length, config.vocab_size])
                diag_soft_id = tf.reshape(
                    sampled_id, [batch_size, seq_length, config.vocab_size])
                diag_hard_id *= scored_mask
                diag_soft_id *= scored_mask

                hard_soft_bias = tf.reduce_sum(
                    tf.sqrt(
                        tf.reduce_sum(tf.pow(diag_hard_id - diag_soft_id, 2),
                                      axis=-1))) / (
                                          1e-10 + tf.reduce_sum(scored_mask))
                tf.summary.scalar('soft_hard_bias', hard_soft_bias)

        return sampled_input_id
Beispiel #4
0
def get_masked_lm_output(config, input_tensor, output_weights, positions,
                         label_ids, label_weights,
                         **kargs):
    """Compute the masked-LM loss and log-probabilities at the masked positions.

    The hidden states at `positions` are gathered, passed through the
    'cls/predictions' head (optional pre-LN layer norm, a dense "transform"
    layer, optional post-LN layer norm, optional embedding projection) and
    scored against `output_weights` plus a per-token output bias.

    Args:
      config: model config, read both by attribute (hidden_size,
        embedding_size, vocab_size, hidden_act, initializer_range) and via
        `.get` (ln_type, embedding).
      input_tensor: encoder output; cast to float32 and gathered at
        `positions` with `bert_utils.gather_indexes`.
      output_weights: output embedding matrix, used with transpose_b=True.
      positions: indices of the predicted tokens (cast to int32).
      label_ids: target token ids (cast to int32 and flattened).
      label_weights: per-prediction weights (cast to float32 and flattened);
        padding predictions are expected to carry 0.0.
      **kargs: optional 'scope' (prefix for 'cls/predictions') and
        'embedding_projection'. 'reuse' is accepted but ignored; variables
        are always created with tf.AUTO_REUSE.

    Returns:
      (loss, per_example_loss, log_probs, label_weights): `loss` is the
      label_weights-weighted mean of per_example_loss (denominator padded by
      1e-5), and `label_weights` is the flattened float32 weight vector.
    """
    projection = kargs.get('embedding_projection', None)

    hidden = tf.cast(input_tensor, tf.float32)
    positions = tf.cast(positions, tf.int32)
    label_ids = tf.cast(label_ids, tf.int32)
    label_weights = tf.cast(label_weights, tf.float32)

    # Keep only the hidden states at the masked positions.
    hidden = bert_utils.gather_indexes(hidden, positions)

    ln_type = config.get('ln_type', 'postln')
    scope_prefix = kargs.get('scope', None)
    scope = (scope_prefix + '/' + 'cls/predictions'
             if scope_prefix else 'cls/predictions')

    tf.logging.info("**** mlm scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Pre-LN normalises before the transform; post-LN (and any other
        # ln_type value) normalises after it.
        if ln_type == 'preln':
            hidden = albert_modules.layer_norm(hidden)

        width = (config.hidden_size
                 if config.get("embedding", "factorized") == "factorized"
                 else config.embedding_size)

        with tf.variable_scope("transform"):
            hidden = tf.layers.dense(
                hidden,
                units=width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if ln_type != 'preln':
                hidden = albert_modules.layer_norm(hidden)

        if projection is None:
            print("==no need for embedding projection==")
        else:
            hidden = tf.matmul(hidden, projection, transpose_b=True)

        # Output weights are tied to the input embeddings; only the bias is
        # owned by this head.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.nn.bias_add(
            tf.matmul(hidden, output_weights, transpose_b=True), output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        flat_labels = tf.reshape(label_ids, [-1])
        flat_weights = tf.reshape(label_weights, [-1])

        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(flat_labels),
            logits=logits)

        # Padded prediction slots have weight 0.0 and drop out of the mean.
        loss = (tf.reduce_sum(flat_weights * per_example_loss) /
                (tf.reduce_sum(flat_weights) + 1e-5))

    return (loss, per_example_loss, log_probs, flat_weights)
Beispiel #5
0
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):
    """Sample generator tokens from the MLM head and splice them into the input.

    The 'cls/predictions' head produces per-token logits. It consists of an
    optional pre-LN layer norm, a dense "transform" layer, an optional
    post-LN layer norm, an optional embedding projection, and tied output
    weights plus a bias. The logits are temperature-scaled and
    log-softmaxed, then either sampled with `gumbel_softmax` or taken
    greedily with argmax.

    kargs:
      scope: optional prefix for the 'cls/predictions' variable scope.
      embedding_projection: optional matrix applied with einsum "abc,dc->abd".
      apply_valid_vocab: falsy -> plain logits; 'topk' -> log of
        `top_k_softmax` probabilities (k = kargs 'topk', default 10); any other
        truthy value -> add -10000 to the first `invalid_size` (default 106)
        vocabulary entries.
      greedy: if truthy, take the argmax instead of a Gumbel sample.
      sampled_binary_mask: optional replace mask (1 = replace, 0 = keep);
        otherwise derived from input_ids != input_ori_ids.
      mask_method: 'only_mask' (default) writes samples only where the
        replace mask is 1; 'all_mask' writes samples at every input_mask
        position except token ids 100, 101 and 102; any other value returns
        the raw samples.
      use_tpu: when falsy, scalar summaries are emitted (default True: none).

    Returns:
      sampled_input_id: [batch, seq], or [batch, seq, gen_sample] when the
      config 'gen_sample' is > 1. It is int32 when a mask_method applies;
      otherwise it keeps the dtype of the argmax output.
    """
    # get_shape_list also validates that input_tensor has rank 3.
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    embedding_projection = kargs.get('embedding_projection', None)
    ln_type = config.get('ln_type', 'postln')
    gen_sample = config.get('gen_sample', 1)
    mask_method = kargs.get('mask_method', 'only_mask')
    use_tpu = kargs.get('use_tpu', True)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Pre-LN normalises before the transform; post-LN (and any other
        # ln_type value) normalises after it.
        if ln_type == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if ln_type != 'preln':
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden -> batch x seq x embedding
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)
        width = bert_utils.get_shape_list(logits, expected_rank=3)[2]

        apply_valid_vocab = kargs.get("apply_valid_vocab", False)
        if not apply_valid_vocab:
            tf.logging.info("****** normal logits *******")
        elif apply_valid_vocab == 'topk':
            prob, _ = top_k_softmax(logits, kargs.get('topk', 10))
            logits = tf.log(prob + 1e-10)
            tf.logging.info("****** topk logits *******")
        else:
            invalid_size = kargs.get("invalid_size", 106)
            # -10000 on the first `invalid_size` vocab entries, 0 elsewhere.
            invalid_mask = tf.concat(
                [tf.ones((1, invalid_size), tf.float32) * (-10000),
                 tf.zeros((1, config.vocab_size - invalid_size), tf.float32)],
                axis=-1)
            # 1 x vocab -> 1 x 1 x vocab so it broadcasts over batch x seq.
            logits += tf.expand_dims(invalid_mask, axis=1)
            tf.logging.info(
                "****** only valid logits ******* , invalid size: %s",
                str(invalid_size))

        logits_tempered = tf.nn.log_softmax(logits /
                                            config.get("temperature", 1.0))
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        if not kargs.get("greedy", False):
            _, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=1.0,
                samples=gen_sample,
                greedy=kargs.get("greedy", False))
            samples = tf.argmax(sampled_logprob, axis=1)
            tf.logging.info("****** normal sample *******")
        else:
            samples = tf.argmax(flat_logits_tempered, axis=-1)
            tf.logging.info("****** greedy sample *******")

        # label_diff_ids: 1.0 where the token is to be replaced, 0.0 where the
        # original token is kept.
        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = sampled_binary_mask
        else:
            label_diff_ids = tf.not_equal(tf.cast(input_ids, tf.int32),
                                          tf.cast(input_ori_ids, tf.int32))
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        if not use_tpu:
            tf.summary.scalar(
                'label_diff_ids',
                tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32))
                / tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if mask_method == 'only_mask':
            tf.logging.info("****** only mask sample *******")
        elif mask_method == 'all_mask':
            tf.logging.info("****** all mask sample *******")
            # Ids 100/101/102 are never replaced (unk/cls/sep per the original
            # comments), and neither are padding positions (input_mask == 0).
            unk_mask = tf.cast(tf.math.equal(input_ori_ids, 100), tf.float32)
            cls_mask = tf.cast(tf.math.equal(input_ori_ids, 101), tf.float32)
            sep_mask = tf.cast(tf.math.equal(input_ori_ids, 102), tf.float32)
            unsampled_mask = (1 - (unk_mask + cls_mask + sep_mask)) * tf.cast(
                input_mask, tf.float32)

        if gen_sample == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length, gen_sample])

            def _repeat_over_samples(tensor):
                # batch x seq -> batch x seq x gen_sample; dtype is preserved,
                # so int ids and float masks can both be broadcast.
                return tf.tile(tf.expand_dims(tensor, axis=-1),
                               [1, 1, gen_sample])

            label_diff_ids = _repeat_over_samples(label_diff_ids)
            input_ori_ids = _repeat_over_samples(input_ori_ids)
            input_mask = _repeat_over_samples(input_mask)
            if mask_method == 'all_mask':
                unsampled_mask = _repeat_over_samples(unsampled_mask)

        # replace_mask: 1.0 where the generator sample is written, 0.0 where
        # the original id is kept; None means the raw samples are returned.
        if mask_method == 'only_mask':
            replace_mask = label_diff_ids
        elif mask_method == 'all_mask':
            replace_mask = unsampled_mask
        else:
            replace_mask = None

        if replace_mask is not None:
            sampled_input_id = (
                replace_mask * tf.cast(sampled_input_id, tf.float32) +
                (1 - replace_mask) * tf.cast(input_ori_ids, tf.float32))
            sampled_input_id = tf.cast(sampled_input_id, tf.int32)

        # Summaries are only defined when a replace mask exists; otherwise
        # there is no scalar accuracy to report.
        if not use_tpu and replace_mask is not None:
            sampled_ids = tf.cast(sampled_input_id, tf.int32)
            original_ids = tf.cast(input_ori_ids, tf.int32)
            sampled_not_equal = tf.cast(
                tf.not_equal(sampled_ids, original_ids),
                tf.float32) * tf.cast(input_mask, tf.float32)
            if mask_method == 'all_mask':
                sampled_equal = tf.cast(tf.equal(sampled_ids, original_ids),
                                        tf.float32) * replace_mask
                tf.summary.scalar('generator_equal_sample_acc',
                                  tf.reduce_sum(sampled_equal))
            # 1 - (mismatched positions / replaceable positions).
            tf.summary.scalar(
                'generator_sample_acc',
                1 - tf.reduce_sum(sampled_not_equal) /
                (1e-10 + tf.reduce_sum(replace_mask)))

    return sampled_input_id
Beispiel #6
0
def seq_mask_masked_lm_output(config, input_tensor, output_weights,
                              input_mask, input_ori_ids, input_ids,
                              sampled_binary_mask, **kargs):
    """Masked-LM loss over every sequence position, weighted by a replace mask.

    The 'cls/predictions' head runs over the full input. It consists of an
    optional pre-LN layer norm, a dense "transform" layer, an optional
    post-LN layer norm, an optional embedding projection, and tied output
    weights plus a bias. The head output is scored against `input_ori_ids`
    with sparse softmax cross-entropy. Only positions where
    `sampled_binary_mask` is 1 (randomly perturbed tokens) contribute to the
    loss.

    `input_mask` and `input_ids` are accepted for interface compatibility but
    are not read.

    kargs:
      scope: optional prefix for the 'cls/predictions' variable scope.
      embedding_projection: optional matrix applied with einsum "abc,dc->abd".

    Returns:
      (loss, per_example_loss, logits, sampled_binary_mask): per_example_loss
      is already multiplied by the float32 mask; loss is its sum divided by
      (1e-10 + mask sum); the returned mask is the float32 cast.
    """
    # Called for its rank-3 check; the individual dimensions are not needed.
    bert_utils.get_shape_list(input_tensor, expected_rank=3)

    projection = kargs.get('embedding_projection', None)
    ln_type = config.get('ln_type', 'postln')

    scope_prefix = kargs.get('scope', None)
    scope = (scope_prefix + '/' + 'cls/predictions'
             if scope_prefix else 'cls/predictions')

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        hidden = input_tensor
        # Pre-LN normalises before the transform; post-LN (and any other
        # ln_type value) normalises after it.
        if ln_type == 'preln':
            hidden = albert_modules.layer_norm(hidden)

        if config.get("embedding", "factorized") == "factorized":
            width = config.hidden_size
        else:
            width = config.embedding_size

        with tf.variable_scope("transform"):
            hidden = tf.layers.dense(
                hidden,
                units=width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if ln_type != 'preln':
                hidden = albert_modules.layer_norm(hidden)

        if projection is None:
            print("==no need for embedding projection==")
        else:
            # batch x seq x hidden, embedding x hidden -> batch x seq x embedding
            print(hidden.get_shape(), projection.get_shape())
            hidden = tf.einsum("abc,dc->abd", hidden, projection)

        output_bias = tf.get_variable(
            "output_bias",
            shape=[config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.nn.bias_add(
            tf.einsum("abc,dc->abd", hidden, output_weights), output_bias)

        # 1.0 marks a position whose original token was randomly perturbed;
        # only those positions are scored.
        replaced = tf.cast(sampled_binary_mask, tf.float32)
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.stop_gradient(input_ori_ids)) * replaced
        loss = tf.reduce_sum(per_example_loss) / (
            1e-10 + tf.reduce_sum(replaced))

    return (loss, per_example_loss, logits, replaced)
Beispiel #7
0
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):
    """Sample replacement token ids from the MLM head (top-k multinomial).

    The 'cls/predictions' head produces per-token logits. It consists of an
    optional pre-LN layer norm, a dense "transform" layer, an optional
    post-LN layer norm, an optional embedding projection, and tied output
    weights plus a bias. The logits are divided by the config 'temperature',
    restricted with `top_k_logits(k = vocab_size / 2)`, and sampled with
    `tf.multinomial`. `config['gen_sample']` samples are drawn per token.

    kargs:
      scope: optional prefix for the 'cls/predictions' variable scope.
      embedding_projection: optional matrix applied with einsum "abc,dc->abd".
      mask_method: 'only_mask' keeps the original id wherever input_ids equals
        input_ori_ids and writes the sample elsewhere; any other value
        (default 'all') returns the raw samples.

    Returns:
      int32 ids of shape [batch, seq], or [batch, seq, gen_sample] when
      gen_sample > 1.
    """
    # get_shape_list also validates that input_tensor has rank 3.
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    embedding_projection = kargs.get('embedding_projection', None)
    ln_type = config.get('ln_type', 'postln')
    gen_sample = config.get('gen_sample', 1)
    keep_original = kargs.get('mask_method', 'all') == 'only_mask'

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Pre-LN normalises before the transform; post-LN (and any other
        # ln_type value) normalises after it.
        if ln_type == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if ln_type != 'preln':
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden -> batch x seq x embedding
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)
        width = bert_utils.get_shape_list(logits, expected_rank=3)[2]

        logits_tempered = logits / config.get("temperature", 1.0)
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])
        flat_logits_tempered_topk = top_k_logits(flat_logits_tempered,
                                                 int(config.vocab_size / 2))
        samples = tf.multinomial(flat_logits_tempered_topk,
                                 num_samples=gen_sample,
                                 output_dtype=tf.int32)

        # 1.0 where input_ids equals input_ori_ids (token left untouched),
        # 0.0 where it differs. The summary tag keeps its historical name.
        label_same_ids = tf.cast(
            tf.equal(tf.cast(input_ids, tf.int32),
                     tf.cast(input_ori_ids, tf.int32)), tf.float32)
        print(label_same_ids, "===label diff ids===")
        tf.summary.scalar(
            'label_diff_ids',
            tf.reduce_sum(label_same_ids * tf.cast(input_mask, tf.float32)) /
            tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if gen_sample == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length, gen_sample])
            if keep_original:
                # batch x seq -> batch x seq x gen_sample; tf.tile keeps the
                # dtype, so int ids need no float round-trip here.
                multiples = [1, 1, gen_sample]
                label_same_ids = tf.tile(
                    tf.expand_dims(label_same_ids, axis=-1), multiples)
                input_ori_ids = tf.tile(
                    tf.expand_dims(input_ori_ids, axis=-1), multiples)

        if keep_original:
            # Keep the original id at untouched positions; use the sample at
            # positions that were masked/replaced.
            sampled_input_id = tf.cast(
                (1 - label_same_ids) * tf.cast(sampled_input_id, tf.float32) +
                label_same_ids * tf.cast(input_ori_ids, tf.float32),
                tf.int32)

    return sampled_input_id