Example 1
    def build_encoder(self, input_ids, input_mask, hidden_dropout_prob,
                      attention_probs_dropout_prob, **kargs):
        reuse = kargs["reuse"]
        input_shape = bert_utils.get_shape_list(input_ids,
                                                expected_rank=[2, 3])
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length],
                                 dtype=tf.int32)

        with tf.variable_scope(self.config.get("scope", "bert"), reuse=reuse):
            with tf.variable_scope("encoder"):
                # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
                # mask of shape [batch_size, seq_length, seq_length] which is used
                # for the attention scores.

                input_shape = bert_utils.get_shape_list(input_ids,
                                                        expected_rank=[2, 3])
                if len(input_shape) == 3:
                    tmp_input_ids = tf.argmax(input_ids, axis=-1)
                else:
                    tmp_input_ids = input_ids

                attention_mask = bert_modules.create_attention_mask_from_input_mask(
                    tmp_input_ids, input_mask)

                seq_type = kargs.get('seq_type', "None")

                if seq_type == "seq2seq":
                    if kargs.get("mask_type", "left2right") == "left2right":
                        mask_sequence = input_mask
                        tf.logging.info(
                            "==apply left2right LM model with casual mask==")
                    elif kargs.get("mask_type", "left2right") == "seq2seq":
                        token_type_ids = kargs.get("token_type_ids", None)
                        tf.logging.info(
                            "==apply left2right LM model with conditional casual mask=="
                        )
                        if token_type_ids is None:
                            token_type_ids = tf.zeros(
                                shape=[batch_size, seq_length], dtype=tf.int32)
                            tf.logging.info(
                                "==conditional mask is set to 0 and degenerate to left2right LM model=="
                            )
                        mask_sequence = token_type_ids
                    attention_mask = bert_utils.generate_seq2seq_mask(
                        attention_mask, mask_sequence, seq_type, **kargs)
                else:
                    tf.logging.info(
                        "==apply bi-directional LM model with bi-directional mask=="
                    )

                # Run the stacked transformer.
                # `sequence_output` shape = [batch_size, seq_length, hidden_size].

                if kargs.get('attention_type',
                             'efficient_attention') == 'normal_attention':
                    tf.logging.info("****** normal attention *******")
                    transformer_model = bert_modules.transformer_model
                elif kargs.get('attention_type',
                               'efficient_attention') == 'efficient_attention':
                    tf.logging.info("****** efficient attention *******")
                    transformer_model = bert_modules.transformer_efficient_model
                elif kargs.get('attention_type',
                               'efficient_attention') == 'rezero_transformer':
                    transformer_model = bert_modules.transformer_rezero_model
                    tf.logging.info("****** rezero_transformer *******")
                else:
                    tf.logging.info("****** normal attention *******")
                    transformer_model = bert_modules.transformer_model

                [
                    self.all_encoder_layers, self.all_attention_scores,
                    self.all_value_outputs
                ] = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=attention_mask,
                    hidden_size=self.config.hidden_size,
                    num_hidden_layers=self.config.num_hidden_layers,
                    num_attention_heads=self.config.num_attention_heads,
                    intermediate_size=self.config.intermediate_size,
                    intermediate_act_fn=bert_modules.get_activation(
                        self.config.hidden_act),
                    hidden_dropout_prob=hidden_dropout_prob,
                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                    initializer_range=self.config.initializer_range,
                    do_return_all_layers=True,
                    attention_fixed_size=self.config.get(
                        'attention_fixed_size', None))
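
For reference, a minimal sketch of the 2D-to-3D mask broadcast performed by create_attention_mask_from_input_mask, assuming the standard BERT semantics (the helper name and shapes follow the upstream BERT code; this is an illustration, not the module's actual source):

import tensorflow as tf

def create_attention_mask_sketch(input_ids, input_mask):
    # input_ids: [batch, seq] (used only for its shape).
    # input_mask: [batch, seq] with 1 for real tokens, 0 for padding.
    # Result: [batch, seq, seq]; position i may attend to j iff input_mask[b, j] == 1.
    shape = tf.shape(input_ids)
    batch_size, seq_length = shape[0], shape[1]
    to_mask = tf.cast(
        tf.reshape(input_mask, [batch_size, 1, seq_length]), tf.float32)
    broadcast_ones = tf.ones([batch_size, seq_length, 1], dtype=tf.float32)
    return broadcast_ones * to_mask  # broadcasts to [batch, seq, seq]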
Example 2
    def model_fn(features, labels, mode):

        shape_lst_a = bert_utils.get_shape_list(features['input_ids_a'])
        batch_size_a = shape_lst_a[0]
        total_length_a = shape_lst_a[1]

        shape_lst_b = bert_utils.get_shape_list(features['input_ids_b'])
        batch_size_b = shape_lst_b[0]
        total_length_b = shape_lst_b[1]

        features['input_ids_a'] = tf.reshape(features['input_ids_a'],
                                             [-1, model_config.max_length])
        features['segment_ids_a'] = tf.reshape(features['segment_ids_a'],
                                               [-1, model_config.max_length])
        features['input_mask_a'] = tf.cast(
            tf.not_equal(features['input_ids_a'], kargs.get('[PAD]', 0)),
            tf.int64)

        features['input_ids_b'] = tf.reshape(
            features['input_ids_b'],
            [-1, model_config.max_predictions_per_seq])
        features['segment_ids_b'] = tf.reshape(
            features['segment_ids_b'],
            [-1, model_config.max_predictions_per_seq])
        features['input_mask_b'] = tf.cast(
            tf.not_equal(features['input_ids_b'], kargs.get('[PAD]', 0)),
            tf.int64)

        features['batch_size'] = batch_size_a
        features['total_length_a'] = total_length_a
        features['total_length_b'] = total_length_b

        model_dict = {}
        for target in ["a", "b"]:
            model = bert_encoder(model_config,
                                 features,
                                 labels,
                                 mode,
                                 target,
                                 reuse=tf.AUTO_REUSE)
            model_dict[target] = model

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss, logits,
             transition_params) = multi_position_crf_classifier(
                 model_config, features, model_dict, num_labels, dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except Exception:
            print("==not count params==")
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

            train_op, hooks = model_io_fn.get_ema_hooks(
                train_op,
                tvars,
                kargs.get('params_moving_average_decay', 0.99),
                scope,
                mode,
                first_stage_steps=opt_config.num_warmup_steps,
                two_stage=True)

            model_io_fn.set_saver()

            if kargs.get("task_index", 1) == 0 and kargs.get(
                    "run_config", None):
                training_hooks = []
            elif kargs.get("task_index", 1) == 0:
                model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                      kargs.get("num_storage_steps", 1000))

                training_hooks = model_io_fn.checkpoint_hook
            else:
                training_hooks = []

            if len(optimizer_fn.distributed_hooks) >= 1:
                training_hooks.extend(optimizer_fn.distributed_hooks)
            print(training_hooks, "==training_hooks==", "==task_index==",
                  kargs.get("task_index", 1))

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                training_hooks=training_hooks)
            print(tf.global_variables(), "==global_variables==")
            if output_type == "sess":
                return {
                    "train": {
                        "loss": loss,
                        "logits": logits,
                        "train_op": train_op
                    },
                    "hooks": training_hooks
                }
            elif output_type == "estimator":
                return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")

            label_weights = tf.cast(features['label_weights'], tf.int32)
            label_seq_length = tf.reduce_sum(label_weights, axis=-1)

            decode_tags, best_score = tf.contrib.crf.crf_decode(
                logits, transition_params, label_seq_length)

            _, hooks = model_io_fn.get_ema_hooks(
                None, None, kargs.get('params_moving_average_decay', 0.99),
                scope, mode)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'decode_tags': decode_tags,
                    "best_score": best_score,
                    "transition_params": transition_params,
                    "logits": logits
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'decode_tags': decode_tags,
                        "best_score": best_score,
                        "transition_params": transition_params,
                        "logits": logits
                    })
                },
                prediction_hooks=[hooks])
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            _, hooks = model_io_fn.get_ema_hooks(
                None, None, kargs.get('params_moving_average_decay', 0.99),
                scope, mode)
            eval_hooks = []

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss),
                        "feature": model.get_pooled_output()
                    }
                }
            elif output_type == "estimator":

                eval_metric_ops = eval_logtis(logits, features, num_labels,
                                              transition_params)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metric_ops=eval_metric_ops,
                    evaluation_hooks=eval_hooks)
                return estimator_spec
        else:
            raise NotImplementedError()
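
A usage sketch for the CRF decoding step above, with hypothetical shapes; tf.contrib.crf.crf_decode takes unary potentials, a [num_tags, num_tags] transition matrix, and per-example lengths, and returns the Viterbi tags and their scores:

import tensorflow as tf

batch, max_seq, num_tags = 8, 32, 5                          # hypothetical sizes
logits = tf.random_normal([batch, max_seq, num_tags])        # unary scores
transition_params = tf.random_normal([num_tags, num_tags])   # CRF transitions
seq_lengths = tf.fill([batch], max_seq)                      # true lengths

decode_tags, best_score = tf.contrib.crf.crf_decode(
    logits, transition_params, seq_lengths)
# decode_tags: [batch, max_seq] int32 Viterbi path, best_score: [batch] float32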
Example 3
def embedding_postprocessor(input_tensor,
														use_token_type=False,
														token_type_ids=None,
														token_type_vocab_size=16,
														token_type_embedding_name="token_type_embeddings",
														use_position_embeddings=True,
														position_embedding_name="position_embeddings",
														initializer_range=0.02,
														max_position_embeddings=512,
														dropout_prob=0.1):
	"""Performs various post-processing on a word embedding tensor.

	Args:
		input_tensor: float Tensor of shape [batch_size, seq_length,
			embedding_size].
		use_token_type: bool. Whether to add embeddings for `token_type_ids`.
		token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
			Must be specified if `use_token_type` is True.
		token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
		token_type_embedding_name: string. The name of the embedding table variable
			for token type ids.
		use_position_embeddings: bool. Whether to add position embeddings for the
			position of each token in the sequence.
		position_embedding_name: string. The name of the embedding table variable
			for positional embeddings.
		initializer_range: float. Range of the weight initialization.
		max_position_embeddings: int. Maximum sequence length that might ever be
			used with this model. This can be longer than the sequence length of
			input_tensor, but cannot be shorter.
		dropout_prob: float. Dropout probability applied to the final output tensor.

	Returns:
		float tensor with same shape as `input_tensor`.

	Raises:
		ValueError: One of the tensor shapes or input values is invalid.
	"""
	input_shape = bert_utils.get_shape_list(input_tensor, expected_rank=3)
	batch_size = input_shape[0]
	seq_length = input_shape[1]
	width = input_shape[2]

	if seq_length > max_position_embeddings:
		raise ValueError("The seq length (%d) cannot be greater than "
										 "`max_position_embeddings` (%d)" %
										 (seq_length, max_position_embeddings))

	output = input_tensor

	if use_token_type:
		if token_type_ids is None:
			raise ValueError("`token_type_ids` must be specified if "
											 "`use_token_type` is True.")
		token_type_table = tf.get_variable(
				name=token_type_embedding_name,
				shape=[token_type_vocab_size, width],
				initializer=create_initializer(initializer_range))
		# This vocab will be small so we always do one-hot here, since it is always
		# faster for a small vocabulary.
		flat_token_type_ids = tf.reshape(token_type_ids, [-1])
		one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
		token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
		token_type_embeddings = tf.reshape(token_type_embeddings,
																			 [batch_size, seq_length, width])
		output += token_type_embeddings

	if use_position_embeddings:
		full_position_embeddings = tf.get_variable(
				name=position_embedding_name,
				shape=[max_position_embeddings, width],
				initializer=create_initializer(initializer_range))
		# Since the position embedding table is a learned variable, we create it
		# using a (long) sequence length `max_position_embeddings`. The actual
		# sequence length might be shorter than this, for faster training of
		# tasks that do not have long sequences.
		#
		# So `full_position_embeddings` is effectively an embedding table
		# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
		# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
		# perform a slice.

		if seq_length < max_position_embeddings:
			position_embeddings = tf.slice(full_position_embeddings, [0, 0],
																		 [seq_length, -1])
		else:
			position_embeddings = full_position_embeddings

		# position_embeddings = tf.cond(tf.less(seq_length, max_position_embeddings), 
		# 												lambda:tf.slice(full_position_embeddings, [0, 0],
		# 																 [seq_length, -1]), 
		# 												lambda:full_position_embeddings)

		num_dims = len(output.shape.as_list())

		# Only the last two dimensions are relevant (`seq_length` and `width`), so
		# we broadcast among the first dimensions, which is typically just
		# the batch size.
		position_broadcast_shape = []
		for _ in range(num_dims - 2):
			position_broadcast_shape.append(1)
		position_broadcast_shape.extend([seq_length, width])
		position_embeddings = tf.reshape(position_embeddings,
																		 position_broadcast_shape)
		output += position_embeddings

	output = layer_norm_and_dropout(output, dropout_prob)
	return output
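
A toy numpy check of the broadcast logic above: the [seq_length, width] slice of the position table is reshaped to [1, seq_length, width] so that the addition broadcasts over the batch dimension:

import numpy as np

batch_size, seq_length, width, max_pos = 2, 4, 3, 8   # toy sizes
output = np.zeros([batch_size, seq_length, width], dtype=np.float32)
table = np.arange(max_pos * width, dtype=np.float32).reshape(max_pos, width)

position_embeddings = table[:seq_length, :]                # tf.slice equivalent
position_embeddings = position_embeddings.reshape(1, seq_length, width)
output += position_embeddings   # added identically to every batch element
assert output.shape == (batch_size, seq_length, width)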
Example 4
def token_generator_igr(config, input_tensor, output_weights, input_ids,
                        input_ori_ids, input_mask, **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        # for 'postln' (or any other value) the tensor passes through unchanged

        # if config.get("embedding", "factorized") == "factorized":
        # 	projection_width = config.hidden_size
        # else:
        # 	projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') != 'preln':
                # 'postln' (and any other value) applies layer norm after the dense transform
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        # logits_tempered = tf.nn.log_softmax(logits, axis=-1)

        # width=config.vocab_size
        flat_logits_tempered = tf.reshape(logits,
                                          [batch_size * seq_length, width])

        num_train_steps = kargs.get('num_train_steps', None)
        if num_train_steps and kargs.get('gumbel_anneal',
                                         "anneal") == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                kargs.get("num_train_steps", 10000),
                end_learning_rate=0.1,
                power=1.0,
                cycle=False)
        elif kargs.get('gumbel_anneal', "anneal") == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # batch x seq x dim
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
                if config.get('gen_sample', 1) > 1:
                    tf.logging.info(
                        "****** apply auto-scale temperature for multi-sampling *******"
                    )
                    annealed_temp = tf.expand_dims(annealed_temp, -1)
        else:
            annealed_temp = 0.01
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))

        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        sampled_logprob_temp, sampled_logprob = iso_gaussian_sample(
            flat_logits_tempered,
            temperature=annealed_temp,
            samples=config.get('gen_sample', 1))

        # argmax on config.vocab_size which is always axis=1
        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        # argmax(logits + gumbel_samples) to sample from a categorical distribution
        if kargs.get('sampled_prob_id', True):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                    config.vocab_size,
                                    axis=1)  # sampled multinomial id
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_id = tf.one_hot(tf.argmax(sampled_logprob_temp, axis=1),
                                    config.vocab_size,
                                    axis=1)  # sampled multinomial id

        # straight-through gumbel softmax estimator
        if kargs.get("straight_through", True):
            tf.logging.info("****** apply straight_through_estimator *******")
            sampled_id = tf.stop_gradient(
                sampled_id -
                sampled_logprob_temp) + flip_gradient(sampled_logprob_temp)
        else:
            tf.logging.info("****** apply gumbel-softmax probs *******")
            sampled_id = flip_gradient(sampled_logprob_temp)

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(
                sampled_binary_mask)  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids,
                        tf.int32)  # 0 for original and 1 for replace
            )

        label_diff_ids = tf.cast(label_diff_ids, tf.float32)

        label_diff_ids = tf.expand_dims(label_diff_ids,
                                        axis=-1)  # batch x seq x 1
        input_ori_ids = tf.one_hot(input_ori_ids,
                                   config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
        else:
            sampled_input_id = tf.reshape(sampled_id, [
                batch_size, seq_length, config.vocab_size,
                config.get('gen_sample', 1)
            ])
            label_diff_ids = tf.expand_dims(label_diff_ids,
                                            axis=-1)  # batch x seq x 1
            input_ori_ids = tf.expand_dims(input_ori_ids,
                                           axis=-1)  # batch x seq x vocab x 1

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - label_diff_ids) * input_ori_ids

        return sampled_input_id
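
The straight-through estimator used above reduces to one idiom: forward-propagate the hard one-hot sample while letting gradients flow through the relaxed probabilities. A minimal sketch without the custom flip_gradient (which additionally reverses the gradient sign here):

import tensorflow as tf

def straight_through(soft_probs):
    # soft_probs: [..., vocab], a differentiable relaxed sample.
    vocab = soft_probs.get_shape().as_list()[-1]
    hard = tf.one_hot(tf.argmax(soft_probs, axis=-1), vocab)  # non-differentiable
    # Forward pass uses `hard`; backward pass sees only `soft_probs`.
    return tf.stop_gradient(hard - soft_probs) + soft_probs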
Example 5
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        if kargs.get('random_generator', '1') == '1':
            if mode == tf.estimator.ModeKeys.TRAIN:
                input_ori_ids = features['input_ori_ids']

                # [output_ids,
                # sampled_binary_mask] = random_input_ids_generation(model_config,
                # 							features['input_ori_ids'],
                # 							features['input_mask'],
                # 							mask_probability=0.2,
                # 							replace_probability=0.1,
                # 							original_probability=0.1,
                # 							**kargs)

                [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                    model_config,
                    features['input_ori_ids'],
                    features['input_mask'], [
                        tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                        for hmm_tran_prob in hmm_tran_prob_list
                    ],
                    mask_probability=0.2,
                    replace_probability=0.0,
                    original_probability=0.0,
                    mask_prior=tf.constant(mask_prior, tf.float32),
                    **kargs)

                features['input_ids'] = output_ids
                tf.logging.info("****** do random generator *******")
            else:
                sampled_binary_mask = None
        else:
            sampled_binary_mask = None

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE,
             scope=generator_scope_prefix)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
            masked_lm_ids = features['input_ori_ids']
        else:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

        if kargs.get("resample_discriminator", False):
            input_ori_ids = features['input_ori_ids']

            [output_ids, sampled_binary_mask
             ] = random_input_ids_generation(model_config,
                                             features['input_ori_ids'],
                                             features['input_mask'],
                                             mask_probability=0.2,
                                             replace_probability=0.1,
                                             original_probability=0.1)

            resample_features = {}
            for key in features:
                resample_features[key] = features[key]

            resample_features['input_ids'] = tf.identity(output_ids)
            model_resample = model_api(model_config,
                                       resample_features,
                                       labels,
                                       mode,
                                       target,
                                       reuse=tf.AUTO_REUSE,
                                       **kargs)

            tf.logging.info("**** apply discriminator resample **** ")
        else:
            model_resample = model
            resample_features = features
            tf.logging.info("**** not apply discriminator resample **** ")

        sampled_ids = token_generator(model_config,
                                      model_resample.get_sequence_output(),
                                      model_resample.get_embedding_table(),
                                      resample_features['input_ids'],
                                      resample_features['input_ori_ids'],
                                      resample_features['input_mask'],
                                      embedding_projection=model_resample.
                                      get_embedding_projection_table(),
                                      scope=generator_scope_prefix,
                                      mask_method='only_mask',
                                      use_tpu=kargs.get('use_tpu', True))

        if model_config.get('gen_sample', 1) == 1:
            input_ids = features['input_ori_ids']
            input_mask = features['input_mask']
            segment_ids = features['segment_ids']
        else:
            input_ids = tf.expand_dims(features['input_ori_ids'], axis=-1)
            # batch x seq_length x 1
            input_ids = tf.einsum(
                'abc,cd->abd', input_ids,
                tf.ones((1, model_config.get('gen_sample', 1))))
            input_ids = tf.cast(input_ids, tf.int32)

            input_shape_list = bert_utils.get_shape_list(input_ids,
                                                         expected_rank=3)
            batch_size = input_shape_list[0]
            seq_length = input_shape_list[1]
            gen_sample = input_shape_list[2]

            sampled_ids = tf.reshape(sampled_ids,
                                     [batch_size * gen_sample, seq_length])
            input_ids = tf.reshape(input_ids,
                                   [batch_size * gen_sample, seq_length])

            input_mask = tf.expand_dims(features['input_mask'], axis=-1)
            input_mask = tf.einsum(
                'abc,cd->abd', input_mask,
                tf.ones((1, model_config.get('gen_sample', 1))))
            input_mask = tf.cast(input_mask, tf.int32)

            segment_ids = tf.expand_dims(features['segment_ids'], axis=-1)
            segment_ids = tf.einsum(
                'abc,cd->abd', segment_ids,
                tf.ones((1, model_config.get('gen_sample', 1))))
            segment_ids = tf.cast(segment_ids, tf.int32)

            segment_ids = tf.reshape(segment_ids,
                                     [batch_size * gen_sample, seq_length])
            input_mask = tf.reshape(input_mask,
                                    [batch_size * gen_sample, seq_length])

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        if generator_scope_prefix:
            """
			"generator/cls/predictions"
			"""
            lm_pretrain_tvars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/predictions",
                not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/seq_relationship",
                not_storage_params=not_storage_params)
        else:
            lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                "cls/seq_relationship", not_storage_params=not_storage_params)

        if model_config.get('embedding_scope', None) is not None:
            embedding_tvars = model_io_fn.get_params(
                model_config.get('embedding_scope', 'bert') + "/embeddings",
                not_storage_params=not_storage_params)
            pretrained_tvars.extend(embedding_tvars)

        pretrained_tvars.extend(lm_pretrain_tvars)
        pretrained_tvars.extend(nsp_pretrain_vars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None
        tf.add_to_collection("generator_loss", loss)
        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "sampled_ids": sampled_ids,  # batch x gen_sample, seg_length
            "sampled_input_ids": input_ids,  # batch x gen_sample, seg_length,
            "sampled_input_mask": input_mask,
            "sampled_segment_ids": segment_ids,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_weights": masked_lm_mask,
            "masked_lm_log_probs": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss,
            "next_sentence_example_loss": nsp_per_example_loss,
            "next_sentence_log_probs": nsp_log_prob,
            "next_sentence_labels": features['next_sentence_labels'],
            "sampled_binary_mask": sampled_binary_mask
        }
        return return_dict
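
The einsum pattern used for gen_sample > 1 above is just a tile: multiplying a [batch, seq, 1] tensor by a (1, k) matrix of ones copies it k times along a new trailing axis. A minimal sketch:

import tensorflow as tf

input_ids = tf.constant([[1, 2, 3]])                        # [batch=1, seq=3]
k = 2                                                       # gen_sample
expanded = tf.expand_dims(tf.cast(input_ids, tf.float32), axis=-1)   # [1, 3, 1]
tiled = tf.einsum('abc,cd->abd', expanded, tf.ones((1, k)))          # [1, 3, 2]
# Equivalent to tf.tile(expanded, [1, 1, k]).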
Example 6
def get_losses(d_out_real, d_out_fake, **kargs):
    # 1:original, 0:fake

    input_shape_list = bert_utils.get_shape_list(d_out_real,
                                                 expected_rank=[1, 2, 3])

    batch_size = input_shape_list[0]
    gan_type = kargs.get('gan_type', 'standard')

    tf.logging.info("**** gan type **** %s", str(gan_type))

    if gan_type == 'standard':  # the non-saturating GAN loss
        d_loss_real = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=d_out_real,
                labels=tf.cast(tf.ones(batch_size), tf.int32)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=d_out_fake,
                labels=tf.cast(tf.zeros(batch_size), tf.int32)))
        d_loss = d_loss_real + d_loss_fake

        g_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=d_out_fake,
                labels=tf.cast(tf.ones(batch_size), tf.int32)))
        tf.logging.info("**** gan type **** %s", str(gan_type))
    elif gan_type == 'JS':  # the vanilla GAN loss
        d_loss_real = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=d_out_real,
                labels=tf.cast(tf.ones(batch_size), tf.int32)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=d_out_fake,
                labels=tf.cast(tf.zeros(batch_size), tf.int32)))
        d_loss = d_loss_real + d_loss_fake

        g_loss = -d_loss_fake
        tf.logging.info("**** gan type **** %s", str(gan_type))

    elif gan_type == 'KL':  # the GAN loss implicitly minimizing KL-divergence
        d_loss_real = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=d_out_real,
                labels=tf.cast(tf.ones(batch_size), tf.int32)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=d_out_fake,
                labels=tf.cast(tf.zeros(batch_size), tf.int32)))
        d_loss = d_loss_real + d_loss_fake

        g_loss = tf.reduce_mean(-d_out_fake)
        tf.logging.info("**** gan type **** %s", str(gan_type))

    elif gan_type == 'hinge':  # the hinge loss
        d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - d_out_real))
        d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + d_out_fake))
        d_loss = d_loss_real + d_loss_fake

        g_loss = -tf.reduce_mean(d_out_fake)
        tf.logging.info("**** gan type **** %s", str(gan_type))

    elif gan_type == 'tv':  # the total variation distance
        d_loss = tf.reduce_mean(tf.tanh(d_out_fake) - tf.tanh(d_out_real))
        g_loss = tf.reduce_mean(-tf.tanh(d_out_fake))
        tf.logging.info("**** gan type **** %s", str(gan_type))

    # elif gan_type == 'wgan-gp':  # WGAN-GP
    # 	d_loss = tf.reduce_mean(d_out_fake) - tf.reduce_mean(d_out_real)
    # 	GP = gradient_penalty(discriminator, x_real_onehot, x_fake_onehot_appr, config)
    # 	d_loss += GP

    # 	g_loss = -tf.reduce_mean(d_out_fake)

    elif gan_type == 'LS':  # LS-GAN
        d_loss_real = tf.reduce_mean(tf.squared_difference(d_out_real, 1.0))
        d_loss_fake = tf.reduce_mean(tf.square(d_out_fake))
        d_loss = d_loss_real + d_loss_fake

        g_loss = tf.reduce_mean(tf.squared_difference(d_out_fake, 1.0))
        tf.logging.info("**** gan type **** %s", str(gan_type))

    elif gan_type == 'RSGAN':  # relativistic standard GAN
        d_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=d_out_real - d_out_fake,
                labels=tf.cast(tf.ones(batch_size), tf.int32)))
        g_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=d_out_fake - d_out_real,
                labels=tf.cast(tf.ones(batch_size), tf.int32)))
        tf.logging.info("**** gan type **** %s", str(gan_type))

    else:
        raise NotImplementedError("Divergence '%s' is not implemented" %
                                  gan_type)

    if not kargs.get('use_tpu', True):
        tf.logging.info("====logging discriminator global loss ====")
        tf.summary.scalar('disc_loss', d_loss)

        tf.summary.scalar('gen_loss', g_loss)

    return {"gen_loss": g_loss, "disc_loss": d_loss}
Example 7
def classifier(config, seq_output, input_ids, sampled_ids, input_mask,
               num_labels, dropout_prob, **kargs):
    """
	input_ids: original input ids
	sampled_ids: generated fake ids
	"""
    output_layer = seq_output
    hidden_size = output_layer.shape[-1].value

    unk_mask = tf.cast(tf.math.equal(input_ids, 100),
                       tf.float32)  # not replace unk
    cls_mask = tf.cast(tf.math.equal(input_ids, 101),
                       tf.float32)  # not replace cls
    sep_mask = tf.cast(tf.math.equal(input_ids, 102),
                       tf.float32)  # not replace sep

    none_replace_mask = unk_mask + cls_mask + sep_mask

    input_mask = tf.cast(input_mask, tf.int32)
    input_mask *= tf.cast(
        1 - none_replace_mask,
        tf.int32)  # cls, unk, sep are not considered as replace or original

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())

    if config.get('ln_type', 'postln') == 'preln':
        output_layer = albert_modules.layer_norm(output_layer)
        print('====preln transformer====')
    elif config.get('ln_type', 'postln') == 'postln':
        output_layer = output_layer
        print('====postln transformer====')
    else:
        output_layer = output_layer
        print('====no layer_norm====')

    output_layer = tf.nn.dropout(output_layer, keep_prob=1 - dropout_prob)

    logits = tf.einsum("abc,dc->abd", output_layer, output_weights)
    logits = tf.nn.bias_add(logits, output_bias)  # batch x seq_length x 2

    input_ids = tf.cast(input_ids, tf.int32)

    input_shape_list = bert_utils.get_shape_list(sampled_ids,
                                                 expected_rank=[2, 3])
    if len(input_shape_list) == 3:
        tmp_sampled_ids = tf.argmax(sampled_ids,
                                    axis=-1)  # batch x seq x vocab
        tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
        tf.logging.info("****** gumbel 3-D sampled_ids *******")
    elif len(input_shape_list) == 2:
        tmp_sampled_ids = sampled_ids
        tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
        tf.logging.info("****** normal 2-D sampled_ids *******")

    ori_sampled_ids = kargs.get('ori_sampled_ids', None)
    if ori_sampled_ids is not None:
        input_shape_list = bert_utils.get_shape_list(ori_sampled_ids,
                                                     expected_rank=[2, 3])
        if len(input_shape_list) == 3:
            tmp_ori_sampled_ids = tf.argmax(ori_sampled_ids,
                                            axis=-1)  # batch x seq x vocab
            tmp_ori_sampled_ids = tf.cast(tmp_ori_sampled_ids, tf.int32)
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif len(input_shape_list) == 2:
            tmp_ori_sampled_ids = tf.cast(ori_sampled_ids, tf.int32)
            tf.logging.info("****** normal 2-D sampled_ids *******")

        masked_not_equal_mask = tf.cast(
            tf.not_equal(input_ids, tmp_ori_sampled_ids), tf.int32)
        masked_not_equal_mask *= tf.cast(input_mask, tf.int32)
    else:
        masked_not_equal_mask = None
    if masked_not_equal_mask is not None:
        tf.logging.info(
            "****** loss mask using masked token mask for masked tokens *******"
        )
        loss_mask = masked_not_equal_mask
    else:
        tf.logging.info(
            "****** loss mask using input_mask for all tokens *******")
        loss_mask = input_mask

    # original:0, replace:1
    not_equal_label_ids = tf.cast(tf.not_equal(input_ids, tmp_sampled_ids),
                                  tf.int32)
    not_equal_label_ids *= tf.cast(input_mask, tf.int32)

    if kargs.get('loss', 'cross_entropy') == 'cross_entropy':
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=tf.stop_gradient(not_equal_label_ids))
    elif kargs.get('loss', 'cross_entropy') == 'focal_loss':
        input_shape_list = bert_utils.get_shape_list(input_ids,
                                                     expected_rank=2)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        not_equal_label_ids_ = tf.reshape(not_equal_label_ids,
                                          [batch_size * seq_length])
        logits_ = tf.reshape(logits, [batch_size * seq_length, -1])
        per_example_loss, _ = loss_utils.focal_loss_binary_v2(
            config, logits_, not_equal_label_ids_)
        per_example_loss = tf.reshape(per_example_loss,
                                      [batch_size, seq_length])

    # loss = per_example_loss * tf.cast(loss_mask, tf.float32)
    # loss = tf.reduce_sum(loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))

    equal_label_ids = (1 - tf.cast(not_equal_label_ids, tf.float32)) * tf.cast(
        loss_mask, tf.float32)
    equal_loss = tf.reduce_sum(per_example_loss * equal_label_ids)

    equal_loss_output = equal_loss / (1e-10 + tf.reduce_sum(equal_label_ids))

    not_equal_loss = tf.reduce_sum(
        per_example_loss *
        tf.cast(not_equal_label_ids, tf.float32))  # not equal:1, equal:0
    not_equal_loss_output = not_equal_loss / (
        1e-10 + tf.reduce_sum(tf.cast(not_equal_label_ids, tf.float32)))

    loss = (equal_loss + not_equal_loss) / (
        1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))

    tf.logging.info("====discriminator classifier use_tpu %s ====",
                    str(kargs.get('use_tpu', True)))
    if not kargs.get('use_tpu', True):
        tf.logging.info("====logging discriminator loss ====")
        tf.summary.scalar('mask_based_loss', loss)

        tf.summary.scalar(
            'equal_loss', equal_loss /
            (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32))))

        tf.summary.scalar(
            'not_equal_loss', not_equal_loss /
            (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32))))

        tf.summary.scalar(
            'loss_decomposition', loss - (equal_loss + not_equal_loss) /
            (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32))))

    return (loss, logits, per_example_loss)
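
The masked-average pattern that the losses above rely on, in isolation: sum per-token losses over the valid positions and divide by the mask count, with a small epsilon guarding against an all-zero mask:

import tensorflow as tf

def masked_mean(per_example_loss, loss_mask):
    # per_example_loss: [batch, seq] float; loss_mask: [batch, seq] in {0, 1}.
    mask = tf.cast(loss_mask, tf.float32)
    return tf.reduce_sum(per_example_loss * mask) / (1e-10 + tf.reduce_sum(mask))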
Example 8
def optimal_discriminator(config, true_model_dict, true_features_dict,
						fake_model_dict, fake_features_dict, **kargs):

	alpha = (1-0.15)/0.15

	sampled_ids = fake_features_dict['input_ids']
	input_shape_list = bert_utils.get_shape_list(fake_features_dict["input_ori_ids"], 
													expected_rank=[2,3])
	batch_size = input_shape_list[0]
	seq_length = input_shape_list[1]

	true_logits = tf.exp(tf.nn.log_softmax(tf.reshape(true_model_dict['masked_lm_log_probs'], [-1, config.vocab_size])))
	fake_logits = tf.exp(tf.nn.log_softmax(tf.reshape(fake_model_dict['masked_lm_log_probs'], [-1, config.vocab_size])))

	labels = tf.reshape(sampled_ids, [-1, 1]) # [batch x seq, 1]
	batch_idxs = tf.range(0, tf.shape(labels)[0])
	batch_idxs = tf.expand_dims(batch_idxs, 1)

	idxs = tf.concat([batch_idxs, labels], 1)
	y_true_pred = tf.gather_nd(true_logits, idxs)
	y_fake_pred = tf.gather_nd(fake_logits, idxs)

	disc_probs = (y_true_pred * (alpha+y_fake_pred)+1e-10) / ((y_fake_pred+alpha*y_true_pred+1e-10))  # batch x seq
	disc_probs = tf.expand_dims(disc_probs, axis=-1) # [batch x seq, 1]
	neg_probs = 1 - disc_probs + 1e-10
	logits = tf.log(tf.concat([disc_probs, neg_probs], axis=-1)+1e-10)

	logits = tf.reshape(logits, [batch_size, seq_length, -1])
	
	input_ids = tf.cast(fake_features_dict['input_ori_ids'], tf.int32)
	unk_mask = tf.cast(tf.equal(input_ids, 100), tf.float32) # not replace unk
	cls_mask =  tf.cast(tf.equal(input_ids, 101), tf.float32) # not replace cls
	sep_mask = tf.cast(tf.equal(input_ids, 102), tf.float32) # not replace sep

	none_replace_mask =  unk_mask + cls_mask + sep_mask

	input_mask = fake_features_dict['input_mask']
	input_mask = tf.cast(input_mask, tf.int32)
	input_mask *= tf.cast(1 - none_replace_mask, tf.int32) # cls, unk, sep are not considered as replace or original

	input_shape_list = bert_utils.get_shape_list(sampled_ids, expected_rank=[2,3])
	if len(input_shape_list) == 3:
		tmp_sampled_ids = tf.argmax(sampled_ids, axis=-1) # batch x seq x vocab
		tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
		tf.logging.info("****** gumbel 3-D sampled_ids *******")
	elif len(input_shape_list) == 2:
		tmp_sampled_ids = sampled_ids
		tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)

	sampled_binary_mask = kargs.get('sampled_binary_mask', None)

	if sampled_binary_mask is not None:
		tf.logging.info("****** loss mask using masked token mask for masked tokens *******")
		loss_mask = sampled_binary_mask
	else:
		tf.logging.info("****** loss mask using input_mask for all tokens *******")
		loss_mask = input_mask

	not_equal_label_ids = tf.cast(tf.not_equal(input_ids, tmp_sampled_ids), tf.int32)
	not_equal_label_ids *= tf.cast(loss_mask, tf.int32)

	print(logits.get_shape(), "===disc logits shape==", not_equal_label_ids.get_shape(), "==label ids shape==")

	per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
													logits=logits,
													labels=tf.stop_gradient(not_equal_label_ids))

	equal_label_ids = (1 - tf.cast(not_equal_label_ids, tf.float32)) * tf.cast(loss_mask, tf.float32)
	equal_per_example_loss = per_example_loss * equal_label_ids
	equal_loss = tf.reduce_sum(equal_per_example_loss)
	equal_loss_all = equal_loss / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))
	equal_loss_output = equal_loss / (1e-10 + tf.reduce_sum(equal_label_ids))

	not_equal_per_example_loss = per_example_loss * tf.cast(not_equal_label_ids, tf.float32)
	not_equal_loss = tf.reduce_sum(not_equal_per_example_loss) # not equal:1, equal:0
	not_equal_loss_all = not_equal_loss / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))
	not_equal_loss_output = not_equal_loss / (1e-10 + tf.reduce_sum(tf.cast(not_equal_label_ids, tf.float32)))

	loss = (equal_loss + not_equal_loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))
	# loss = equal_loss_output + not_equal_loss_output * 0.1
	tf.logging.info("====discriminator classifier use_tpu %s ====", str(kargs.get('use_tpu', True)))
	if not kargs.get('use_tpu', True):
		tf.logging.info("====logging discriminator loss ====")
		tf.summary.scalar('mask_based_loss', 
							loss)

		loss = per_example_loss * tf.cast(loss_mask, tf.float32)
		loss = tf.reduce_sum(loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))

		tf.summary.scalar('equal_loss', 
							equal_loss/(1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32))))

		tf.summary.scalar('not_equal_loss', 
							not_equal_loss/(1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32))))

		tf.summary.scalar('loss_decomposition', 
							loss - (equal_loss+not_equal_loss)/(1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32))))

	return (loss, logits, per_example_loss)
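
In closed form, with p_t = y_true_pred, p_f = y_fake_pred, and alpha = (1 - 0.15) / 0.15 for a 15% corruption rate, the probability computed above is

    D(x) = p_t * (alpha + p_f) / (p_f + alpha * p_t)

and log([D, 1 - D]) is used as two-class logits in the replaced-token-detection cross-entropy; the 1e-10 terms only guard against log(0).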
Example 9
	def model_fn(features, labels, mode):

		task_type = kargs.get("task_type", "cls")
		num_task = kargs.get('num_task', 1)
		temp = kargs.get('temp', 0.1)

		print("==task_type==", task_type)

		model_io_fn = model_io.ModelIO(model_io_config)
		label_ids = tf.cast(features["{}_label_ids".format(task_type)], dtype=tf.int32)

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
			is_training = True
		else:
			dropout_prob = 0.0
			is_training = False

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		if kargs.get("get_pooled_output", "pooled_output") == "pooled_output":
			pooled_feature = model.get_pooled_output()
		elif kargs.get("get_pooled_output", "task_output") == "task_output":
			pooled_feature_dict = model.get_task_output()
			pooled_feature = pooled_feature_dict['pooled_feature']

		shape_list = bert_utils.get_shape_list(pooled_feature_dict['feature_a'], 
												expected_rank=[2])
		batch_size = shape_list[0]

		if kargs.get('apply_head_proj', False):
			with tf.variable_scope(scope+"/head_proj", reuse=tf.AUTO_REUSE):
				feature_a = simclr_utils.projection_head(pooled_feature_dict['feature_a'], 
										is_training, 
										head_proj_dim=128,
										num_nlh_layers=1,
										head_proj_mode='nonlinear',
										name='head_contrastive')
				pooled_feature_dict['feature_a'] = feature_a

			with tf.variable_scope(scope+"/head_proj", reuse=tf.AUTO_REUSE):
				feature_b = simclr_utils.projection_head(pooled_feature_dict['feature_b'], 
										is_training, 
										head_proj_dim=128,
										num_nlh_layers=1,
										head_proj_mode='nonlinear',
										name='head_contrastive')
				pooled_feature_dict['feature_b'] = feature_b
			tf.logging.info("****** apply contrastive feature projection *******")
		else:
			feature_a = pooled_feature_dict['feature_a']
			feature_b = pooled_feature_dict['feature_b']
			tf.logging.info("****** not apply projection *******")

		loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)], tf.float32)
		
		if kargs.get('merge_mode', 'all') == 'all':
			input_ids = tf.concat([features['input_ids_a'], features['input_ids_b']], axis=0)
			hidden_repres = tf.concat([feature_a, feature_b], axis=0)
			sent_repres = tf.concat([pooled_feature_dict['sent_repres_a'], pooled_feature_dict['sent_repres_b']], axis=0)
			tf.logging.info("****** double batch *******")
		else:
			input_ids = features['input_ids_b']
			hidden_repres = feature_b
			sent_repres = pooled_feature_dict['sent_repres_b']
			tf.logging.info("****** single batch b *******")
		sequence_mask = tf.to_float(tf.not_equal(input_ids, 
											kargs.get('[PAD]', 0)))

		with tf.variable_scope("vae/connect", reuse=tf.AUTO_REUSE):
			with tf.variable_scope("z_mean"):
				z_mean = tf.layers.dense(
							hidden_repres,
							128,
							use_bias=False,
							activation=None,
							kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
				bn_z_mean = vae_utils.mean_normalize_scale(z_mean, 
												is_training, 
												"bn_mean", 
												tau=0.5,
												reuse=tf.AUTO_REUSE,
												**kargs)

			with tf.variable_scope("z_std"):
				z_std = tf.layers.dense(
							hidden_repres,
							128,
							use_bias=True,
							activation=tf.nn.relu,
							kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))	
				bn_z_std = vae_utils.std_normalize_scale(z_std, 
							is_training, 
							"bn_std", 
							tau=0.5,
							reuse=tf.AUTO_REUSE,
							**kargs)

			gaussian_noise = vae_utils.hidden_sampling(bn_z_mean, bn_z_std, **kargs)
			sent_repres_shape = bert_utils.get_shape_list(sent_repres, expected_rank=[3])
			with tf.variable_scope("vae/projection"):
				gaussian_noise = tf.layers.dense(
							gaussian_noise,
							sent_repres_shape[-1],
							use_bias=False,
							activation=None,
							kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
			sent_repres += tf.expand_dims(gaussian_noise, axis=1)

		with tf.variable_scope("vae/decoder", reuse=tf.AUTO_REUSE):
			sequence_output = dgcnn_utils.dgcnn(
												sent_repres, 
												sequence_mask,
												num_layers=model_config['cnn_num_layers'], 
												dilation_rates=model_config.get('cnn_dilation_rates', [1,2]),
												strides=model_config.get('cnn_strides', [1,1]),
												num_filters=model_config.get('cnn_num_filters', [128,128]), 
												kernel_sizes=model_config.get('cnn_filter_sizes', [3,3]), 
												is_training=is_training,
												scope_name="textcnn_encoder/textcnn/forward", 
												reuse=tf.AUTO_REUSE, 
												activation=tf.nn.relu,
												is_casual=model_config['is_casual'],
												padding=model_config.get('padding', 'same')
												)
			sequence_output_logits = model.build_other_output_logits(sequence_output, reuse=tf.AUTO_REUSE)
		resc_loss = vae_utils.reconstruction_loss(sequence_output_logits, 
												input_ids,
												name="decoder_resc",
												use_tpu=False)
		kl_loss = vae_utils.kl_loss(bn_z_mean, bn_z_std, 
									opt_config.get('num_train_steps', 10000), 
									name="kl_div",
									use_tpu=False,
									kl_anneal="kl_anneal")
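		# total loss is the negative ELBO: token reconstruction plus the
		# (annealed) KL divergence between N(mean, std^2) and the unit Gaussian prior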
		loss = resc_loss + kl_loss
		task_loss = loss
		params_size = model_io_fn.count_params(model_config.scope)
		print("==total encoder params==", params_size)

		if kargs.get("feature_distillation", True):
			universal_feature_a = features.get("input_ids_a_features", None)
			universal_feature_b = features.get("input_ids_b_features", None)
			
			if universal_feature_a is None or universal_feature_b is None:
				tf.logging.info("****** not apply feature distillation *******")
				feature_loss = tf.constant(0.0)
			else:
				feature_a = pooled_feature_dict['feature_a']
				feature_a_shape = bert_utils.get_shape_list(feature_a, expected_rank=[2,3])
				pretrain_feature_a_shape = bert_utils.get_shape_list(universal_feature_a, expected_rank=[2,3])
				if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
					with tf.variable_scope(scope+"/feature_proj", reuse=tf.AUTO_REUSE):
						proj_feature_a = tf.layers.dense(feature_a, pretrain_feature_a_shape[-1])
					# with tf.variable_scope(scope+"/feature_rec", reuse=tf.AUTO_REUSE):
					# 	proj_feature_a_rec = tf.layers.dense(proj_feature_a, feature_a_shape[-1])
					# loss += tf.reduce_mean(tf.reduce_sum(tf.square(proj_feature_a_rec-feature_a), axis=-1))/float(num_task)
					tf.logging.info("****** apply auto-encoder for feature compression *******")
				else:
					proj_feature_a = feature_a
				feature_a_norm = tf.stop_gradient(tf.sqrt(tf.reduce_sum(tf.pow(proj_feature_a, 2), axis=-1, keepdims=True))+1e-20)
				proj_feature_a /= feature_a_norm

				feature_b = pooled_feature_dict['feature_b'] 
				if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
					with tf.variable_scope(scope+"/feature_proj", reuse=tf.AUTO_REUSE):
						proj_feature_b = tf.layers.dense(feature_b, pretrain_feature_a_shape[-1])
					# with tf.variable_scope(scope+"/feature_rec", reuse=tf.AUTO_REUSE):
					# 	proj_feature_b_rec = tf.layers.dense(proj_feature_b, feature_a_shape[-1])
					# loss += tf.reduce_mean(tf.reduce_sum(tf.square(proj_feature_b_rec-feature_b), axis=-1))/float(num_task)
					tf.logging.info("****** apply auto-encoder for feature compression *******")
				else:
					proj_feature_b = feature_b

				feature_b_norm = tf.stop_gradient(tf.sqrt(tf.reduce_sum(tf.pow(proj_feature_b, 2), axis=-1, keepdims=True))+1e-20)
				proj_feature_b /= feature_b_norm

				feature_a_distillation = tf.reduce_mean(tf.square(universal_feature_a-proj_feature_a), axis=-1)
				feature_b_distillation = tf.reduce_mean(tf.square(universal_feature_b-proj_feature_b), axis=-1)

				feature_loss = tf.reduce_mean((feature_a_distillation + feature_b_distillation)/2.0)/float(num_task)
				loss += feature_loss
				tf.logging.info("****** apply prertained feature distillation *******")

		if kargs.get("embedding_distillation", True):
			word_embed = model.emb_mat
			random_embed_shape = bert_utils.get_shape_list(word_embed, expected_rank=[2,3])
			print("==random_embed_shape==", random_embed_shape)
			pretrained_embed = kargs.get('pretrained_embed', None)
			if pretrained_embed is None:
				tf.logging.info("****** not apply prertained feature distillation *******")
				embed_loss = tf.constant(0.0)
			else:
				pretrain_embed_shape = bert_utils.get_shape_list(pretrained_embed, expected_rank=[2,3])
				print("==pretrain_embed_shape==", pretrain_embed_shape)
				if random_embed_shape[-1] != pretrain_embed_shape[-1]:
					with tf.variable_scope(scope+"/embedding_proj", reuse=tf.AUTO_REUSE):
						proj_embed = tf.layers.dense(word_embed, pretrain_embed_shape[-1])
				else:
					proj_embed = word_embed
				
				embed_loss = tf.reduce_mean(tf.reduce_mean(tf.square(proj_embed-pretrained_embed), axis=-1))/float(num_task)
				loss += embed_loss
				tf.logging.info("****** apply prertained feature distillation *******")

		if mode == tf.estimator.ModeKeys.TRAIN:
			multi_task_config = kargs.get("multi_task_config", {})
			if multi_task_config.get(task_type, {}).get("lm_augumentation", False):
				print("==apply lm_augumentation==")
				masked_lm_positions = features["masked_lm_positions"]
				masked_lm_ids = features["masked_lm_ids"]
				masked_lm_weights = features["masked_lm_weights"]
				(masked_lm_loss,
				masked_lm_example_loss, 
				masked_lm_log_probs) = pretrain.get_masked_lm_output(
												model_config, 
												model.get_sequence_output(), 
												model.get_embedding_table(),
												masked_lm_positions, 
												masked_lm_ids, 
												masked_lm_weights,
												reuse=model_reuse)

				masked_lm_loss_mask = tf.expand_dims(loss_mask, -1) * tf.ones((1, multi_task_config[task_type]["max_predictions_per_seq"]))
				masked_lm_loss_mask = tf.reshape(masked_lm_loss_mask, (-1, ))

				masked_lm_label_weights = tf.reshape(masked_lm_weights, [-1])
				masked_lm_loss_mask *= tf.cast(masked_lm_label_weights, tf.float32)

				masked_lm_example_loss *= masked_lm_loss_mask  # multiply task mask
				masked_lm_loss = tf.reduce_sum(masked_lm_example_loss) / (1e-10+tf.reduce_sum(masked_lm_loss_mask))
				loss += multi_task_config[task_type]["masked_lm_loss_ratio"]*masked_lm_loss

				masked_lm_label_ids = tf.reshape(masked_lm_ids, [-1])
				
				print(masked_lm_log_probs.get_shape(), "===masked lm log probs===")
				print(masked_lm_label_ids.get_shape(), "===masked lm ids===")
				print(masked_lm_label_weights.get_shape(), "===masked lm mask===")

				lm_acc = build_accuracy(masked_lm_log_probs, masked_lm_label_ids, masked_lm_loss_mask)

		if kargs.get("task_invariant", "no") == "yes":
			print("==apply task adversarial training==")
			with tf.variable_scope(scope+"/dann_task_invariant", reuse=model_reuse):
				(_, 
				task_example_loss, 
				task_logits)  = distillation_utils.feature_distillation(model.get_pooled_output(), 
														1.0, 
														features["task_id"], 
														kargs.get("num_task", 7),
														dropout_prob, 
														True)
				masked_task_example_loss = loss_mask * task_example_loss
				masked_task_loss = tf.reduce_sum(masked_task_example_loss) / (1e-10+tf.reduce_sum(loss_mask))
				loss += kargs.get("task_adversarial", 1e-2) * masked_task_loss

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		vae_tvars = model_io_fn.get_params("vae", 
										not_storage_params=not_storage_params)

		if mode == tf.estimator.ModeKeys.TRAIN:
			multi_task_config = kargs.get("multi_task_config", {})
			if multi_task_config.get(task_type, {}).get("lm_augumentation", False):
				print("==apply lm_augumentation==")
				masked_lm_pretrain_tvars = model_io_fn.get_params("cls/predictions", 
												not_storage_params=not_storage_params)
				tvars.extend(masked_lm_pretrain_tvars)

		try:
			params_size = model_io_fn.count_params(model_config.scope)
			print("==total params==", params_size)
		except:
			print("==not count params==")
		# print(tvars)
		if load_pretrained == "yes":

			[assignment_map, 
			initialized_variable_names] = model_io_utils.get_assigment_map_from_checkpoint(
															tvars, 
															init_checkpoint,
															exclude_scope="")
			[assignment_map_vae, 
			initialized_variable_names_vae] = model_io_utils.get_assigment_map_from_checkpoint(
															vae_tvars, 
															init_checkpoint,
															exclude_scope="vae/decoder")
			assignment_map.update(assignment_map_vae)
			initialized_variable_names.update(initialized_variable_names_vae)

			model_io_utils.init_pretrained(assignment_map, initialized_variable_names,
										tvars+vae_tvars, init_checkpoint)

		if mode == tf.estimator.ModeKeys.TRAIN:

			train_metric_dict = train_metric(input_ids,
											sequence_output_logits,
											**kargs)
			return_dict = {
					"loss":loss, 
					"tvars":tvars+vae_tvars
				}
			return_dict["perplexity"] = train_metric_dict['perplexity']
			return_dict["token_acc"] = train_metric_dict['token_acc']
			return_dict["kl_div"] = kl_loss
			if kargs.get("task_invariant", "no") == "yes":
				return_dict["{}_task_loss".format(task_type)] = masked_task_loss
				task_acc = build_accuracy(task_logits, features["task_id"], loss_mask)
				return_dict["{}_task_acc".format(task_type)] = task_acc
			if multi_task_config.get(task_type, {}).get("lm_augumentation", False):
				return_dict["{}_masked_lm_loss".format(task_type)] = masked_lm_loss
				return_dict["{}_masked_lm_acc".format(task_type)] = lm_acc
			if kargs.get("embedding_distillation", True):
				return_dict["embed_loss"] = embed_loss*float(num_task)
			else:
				return_dict["embed_loss"] = task_loss
			if kargs.get("feature_distillation", True):
				return_dict["feature_loss"] = feature_loss*float(num_task)
			else:
				return_dict["feature_loss"] = task_loss
			return_dict["task_loss"] = task_loss
			return return_dict
		elif mode == tf.estimator.ModeKeys.EVAL:
			eval_dict = {
				"loss":loss, 
				"logits":logits,
				"feature":model.get_pooled_output()
			}
			if kargs.get("adversarial", "no") == "adversarial":
				 eval_dict["task_logits"] = task_logits
			return eval_dict
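The `vae/connect` block above is a standard VAE bottleneck: dense `z_mean`/`z_std` heads, a reparameterized sample, and a KL penalty against a unit Gaussian prior. A minimal sketch of what `vae_utils.hidden_sampling` and `vae_utils.kl_loss` are assumed to compute (the real helpers also handle the batch-norm scaling and KL annealing seen above):

import tensorflow as tf

def hidden_sampling_sketch(z_mean, z_std):
    # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
    eps = tf.random_normal(tf.shape(z_mean))
    return z_mean + z_std * eps

def kl_loss_sketch(z_mean, z_std):
    # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1)
    log_var = 2.0 * tf.log(z_std + 1e-10)
    kl = -0.5 * tf.reduce_sum(
        1.0 + log_var - tf.square(z_mean) - tf.exp(log_var), axis=-1)
    return tf.reduce_mean(kl)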
Example #10
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):

    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        # 'postln' (and any other value) uses the encoder output unchanged

        # if config.get("embedding", "factorized") == "factorized":
        # 	projection_width = config.hidden_size
        # else:
        # 	projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            # 'preln' skips the norm here; 'postln' (and any other value) applies it
            if config.get('ln_type', 'postln') != 'preln':
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        if not kargs.get("apply_valid_vocab", False):
            logits = logits
            tf.logging.info("****** normal logits *******")
        elif kargs.get("apply_valid_vocab", False) == 'topk':
            prob, _ = top_k_softmax(logits, kargs.get('topk', 10))
            logits = tf.log(prob + 1e-10)
            tf.logging.info("****** topk logits *******")
        else:
            invalid_size = kargs.get("invalid_size", 106)
            invalid_mask = tf.cast(
                tf.ones((1, invalid_size)) * (-10000), tf.float32)
            valid_mask = tf.cast(
                tf.zeros((1, config.vocab_size - invalid_size)), tf.float32)
            invalid_mask = tf.concat([invalid_mask, valid_mask], axis=-1)
            invalid_mask = tf.expand_dims(
                invalid_mask, axis=1)  # (1, 1, vocab), broadcasts over batch x seq
            logits += tf.cast(invalid_mask, tf.float32)
            tf.logging.info(
                "****** only valid logits ******* , invalid size: %s",
                str(invalid_size))

        logits_tempered = tf.nn.log_softmax(logits /
                                            config.get("temperature", 1.0))

        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        # flat_logits_tempered_topk = top_k_logits(flat_logits_tempered, int(config.vocab_size/2
        if not kargs.get("greedy", False):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=1.0,
                samples=config.get('gen_sample', 1),
                greedy=kargs.get("greedy", False))

            samples = tf.argmax(sampled_logprob, axis=1)  # batch x seq
            tf.logging.info("****** normal sample *******")
        else:
            samples = tf.argmax(flat_logits_tempered, axis=-1)
            tf.logging.info("****** greedy sample *******")

        # samples = tf.multinomial(flat_logits_tempered,
        # 						num_samples=config.get('gen_sample', 1),
        # 						output_dtype=tf.int32)

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = sampled_binary_mask  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids,
                        tf.int32)  # 0 for original and 1 for replace
            )
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        if not kargs.get('use_tpu', True):
            tf.summary.scalar(
                'label_diff_ids',
                tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32))
                / tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                input_ori_ids_1 = input_ori_ids
                unk_mask = tf.cast(tf.math.equal(input_ori_ids_1, 100),
                                   tf.float32)  # not replace unk
                cls_mask = tf.cast(tf.math.equal(input_ori_ids_1, 101),
                                   tf.float32)  # not replace cls
                sep_mask = tf.cast(tf.math.equal(input_ori_ids_1, 102),
                                   tf.float32)  # not replace sep
                unsampled_mask = (1 -
                                  (unk_mask + cls_mask + sep_mask)) * tf.cast(
                                      input_mask, tf.float32)
                # unsampled_mask = tf.expand_dims(unsampled_mask, axis=[-1]) # batch x seq x 1
                tf.logging.info("****** all mask sample *******")
                sampled_input_id = unsampled_mask * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - unsampled_mask) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length,
                          config.get('gen_sample', 1)])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                # batch x seq_length x 1
                label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
                label_diff_ids = tf.einsum(
                    'abc,cd->abd', label_diff_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                # batch x seq_length x 1
                input_ori_ids = tf.cast(tf.expand_dims(input_ori_ids, axis=-1),
                                        tf.float32)
                input_ori_ids = tf.einsum(
                    'abc,cd->abd', input_ori_ids,
                    tf.ones((1, config.get('gen_sample', 1))))

                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - label_diff_ids) * input_ori_ids
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

                input_mask = tf.cast(tf.expand_dims(input_mask, axis=-1),
                                     tf.float32)
                input_mask = tf.einsum(
                    'abc,cd->abd', input_mask,
                    tf.ones((1, config.get('gen_sample', 1))))

        if not kargs.get('use_tpu', True):

            sampled_not_equal_id = tf.not_equal(
                tf.cast(sampled_input_id, tf.int32),
                tf.cast(input_ori_ids, tf.int32))
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(
                                            input_mask, tf.float32)

            sampled_equal_id = tf.equal(tf.cast(sampled_input_id, tf.int32),
                                        tf.cast(input_ori_ids, tf.int32))

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                sampled_equal = tf.cast(sampled_equal_id,
                                        tf.float32) * tf.cast(
                                            unsampled_mask, tf.float32)
                tf.summary.scalar('generator_equal_sample_acc',
                                  tf.reduce_sum(sampled_equal))
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                sampled_equal = tf.reduce_sum(sampled_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        # sampled_not_equal_id = tf.not_equal(
        # 		tf.cast(sampled_input_id, tf.int32),
        # 		tf.cast(input_ori_ids, tf.int32)
        # )
        # sampled_not_equal = tf.cast(sampled_not_equal_id, tf.float32) * tf.cast(input_mask, tf.float32)
        # sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))

        # if not kargs.get('use_tpu', True):
        # 	tf.summary.scalar('generator_sample_acc',
        # 					sampled_not_equal)

        return sampled_input_id
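`gumbel_softmax` here is an external helper; a minimal sketch of the Gumbel-softmax sampling it is assumed to perform over the flattened logits (the real helper also supports drawing several samples per position):

import tensorflow as tf

def gumbel_softmax_sketch(logits, temperature=1.0):
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1)
    uniform = tf.random_uniform(tf.shape(logits), minval=1e-10, maxval=1.0)
    gumbel_noise = -tf.log(-tf.log(uniform))
    # perturb-and-softmax: a differentiable relaxation whose argmax
    # recovers a hard categorical sample (the Gumbel-max trick)
    log_probs = tf.nn.log_softmax((logits + gumbel_noise) / temperature)
    return log_probs, tf.argmax(log_probs, axis=-1)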
Example #11
    def apply_gradients(self,
                        grads_and_vars,
                        global_step=None,
                        name=None,
                        learning_rate=None):
        """See base class."""

        if learning_rate is None:
            learning_rate = self.learning_rate
            tf.logging.info("***** use default learning rate: %s *****",
                            str(learning_rate))
        else:
            tf.logging.info("***** use provided learning rate: %s *****",
                            str(learning_rate))

        assignments = []
        for (grad, param) in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = self._get_variable_name(param.name)

            tf.logging.info("***** apply gradients parameter name ***** %s",
                            param_name)
            tf.logging.info("***** param: %s learning rate: %s ***** ",
                            param_name, str(learning_rate))

            shape_list = bert_utils.get_shape_list(param, expected_rank=[1, 2])

            # decay_rate = 1 - tf.pow(tf.cast(tf.train.get_or_create_global_step(), tf.float32) + 1.0, -0.8)
            decay_rate = self.beta_2
            grad_squared = tf.square(grad) + self.epsilon1

            update_scale = learning_rate
            # update_scale = self.learning_rate * tf.cast(self._parameter_scale(param), dtype=tf.float32)

            # HACK: Make things dependent on grad.
            # This confounds the XLA rewriter and keeps it from fusing computations
            # across different variables.  This fusion is a bad for HBM usage, since
            # it causes the gradients to persist in memory.
            grad_squared_mean = tf.reduce_mean(grad_squared)
            decay_rate += grad_squared_mean * 1e-30
            update_scale += grad_squared_mean * 1e-30

            # END HACK

            if self._use_factored(shape_list):
                num_rows, num_columns = shape_list

                vr = tf.get_variable(name=param_name + "/adafactor_vr",
                                     shape=[num_rows],
                                     dtype=tf.float32,
                                     trainable=False,
                                     initializer=tf.zeros_initializer())
                vc = tf.get_variable(name=param_name + "/adafactor_vc",
                                     shape=[num_columns],
                                     dtype=tf.float32,
                                     trainable=False,
                                     initializer=tf.zeros_initializer())

                next_vr = decay_rate * vr + (1 - decay_rate) * tf.reduce_mean(
                    grad_squared, 1)
                next_vc = decay_rate * vc + (1 - decay_rate) * tf.reduce_mean(
                    grad_squared, 0)

                long_term_mean = tf.reduce_mean(next_vr, -1, keepdims=True)
                r_factor = tf.rsqrt(next_vr / long_term_mean + self.epsilon1)
                c_factor = tf.rsqrt(next_vc + self.epsilon1)
                update = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(
                    c_factor, -2)

                assignments.append(
                    vr.assign(next_vr, use_locking=self.use_locking))
                assignments.append(
                    vc.assign(next_vc, use_locking=self.use_locking))
            else:
                v = tf.get_variable(name=param_name + "/adafactor_v",
                                    shape=shape_list,
                                    dtype=tf.float32,
                                    trainable=False,
                                    initializer=tf.zeros_initializer())
                next_v = decay_rate * v + (1 - decay_rate) * grad_squared

                assignments.append(
                    v.assign(next_v, use_locking=self.use_locking))
                update = grad * tf.rsqrt(next_v + self.epsilon1)

            clipping_denom = tf.maximum(
                1.0,
                reduce_rms(update) / self.clipping_rate)
            update /= clipping_denom

            # Do weight decay
            # Just adding the square of the weights to the loss function is *not*
            # the correct way of using L2 regularization/weight decay with Adam,
            # since that will interact with the m and v parameters in strange ways.
            #
            # Instead we want to decay the weights in a manner that doesn't interact
            # with the m/v parameters. This is equivalent to adding the square
            # of the weights to the loss with plain (non-momentum) SGD.
            if self._do_use_weight_decay(param_name):
                update += self.weight_decay_rate * param

            update_with_lr = update_scale * update
            next_param = param - update_with_lr

            assignments.append(
                param.assign(next_param, use_locking=self.use_locking))
        return tf.group(*assignments, name=name)
Example #12
def random_input_ids_generation_v1(config, input_ori_ids, input_mask, **kargs):

    mask_id = kargs.get('mask_id', 103)
    valid_vocab = kargs.get('valid_vocab', 105)

    input_ori_ids = tf.cast(input_ori_ids, tf.int32)
    input_mask = tf.cast(input_mask, tf.int32)

    unk_mask = tf.cast(tf.math.equal(input_ori_ids, 100),
                       tf.float32)  # not replace unk
    cls_mask = tf.cast(tf.math.equal(input_ori_ids, 101),
                       tf.float32)  # not replace cls
    sep_mask = tf.cast(tf.math.equal(input_ori_ids, 102),
                       tf.float32)  # not replace sep

    none_replace_mask = unk_mask + cls_mask + sep_mask

    input_shape_list = bert_utils.get_shape_list(input_ori_ids,
                                                 expected_rank=2)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    if kargs.get('annealed_mask_prob', False):
        mask_probability = 1 - tf.train.polynomial_decay(
            0.95,
            tf.train.get_or_create_global_step(),
            kargs.get("num_train_steps", 10000) * 0.1,
            end_learning_rate=0.85,
            power=1.0,
            cycle=False)
        tf.logging.info("**** apply annealed_mask_prob **** ")
    else:
        mask_probability = 0.15
        tf.logging.info("**** apply fixed_mask_prob %s **** ",
                        str(mask_probability))

    # must_have_one = tf.cast(tf.expand_dims(tf.eye(seq_length)[4], axis=[0]), tf.int32) # batch x seq_length
    # must_have_one = must_have_one * input_mask * (1 - tf.cast(none_replace_mask, tf.int32))
    sample_probs = tf.ones_like(input_ori_ids) * input_mask * (
        1 - tf.cast(none_replace_mask, tf.int32))
    sample_probs = mask_probability * tf.cast(
        sample_probs, tf.float32
    )  #+ 0.8 * tf.cast(must_have_one, tf.float32) # mask 15% token

    noise_dist = tf.distributions.Bernoulli(probs=sample_probs,
                                            dtype=tf.float32)
    sampled_binary_mask = noise_dist.sample()
    sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32)

    # mask_binary_probs = 0.8 * sampled_binary_mask # use 80% [mask] for masked token
    # mask_noise_dist = tf.distributions.Bernoulli(probs=mask_binary_probs, dtype=tf.float32)
    # sampled_mask_binary_mask = mask_noise_dist.sample()
    # sampled_mask_binary_mask = tf.cast(sampled_mask_binary_mask, tf.float32)

    # replace_binary_probs = 0.5 * (sampled_binary_mask - sampled_mask_binary_mask) # use 10% [mask] to replace token
    # replace_noise_dist = tf.distributions.Bernoulli(probs=replace_binary_probs, dtype=tf.float32)
    # sampled_replace_binary_mask = replace_noise_dist.sample()
    # sampled_replace_binary_mask = tf.cast(sampled_replace_binary_mask, tf.float32)

    # ori_binary_probs = 1.0 * (sampled_binary_mask - sampled_mask_binary_mask - sampled_replace_binary_mask)
    # ori_noise_dist = tf.distributions.Bernoulli(probs=ori_binary_probs, dtype=tf.float32)
    # sampled_ori_binary_mask = ori_noise_dist.sample()
    # sampled_ori_binary_mask = tf.cast(sampled_ori_binary_mask, tf.float32)

    replace_binary_probs = 0.1 * (
        sampled_binary_mask)  # replace ~10% of sampled positions with a random token
    replace_noise_dist = tf.distributions.Bernoulli(probs=replace_binary_probs,
                                                    dtype=tf.float32)
    sampled_replace_binary_mask = replace_noise_dist.sample()
    sampled_replace_binary_mask = tf.cast(sampled_replace_binary_mask,
                                          tf.float32)

    ori_binary_probs = 0.1 * (sampled_binary_mask -
                              sampled_replace_binary_mask)
    ori_noise_dist = tf.distributions.Bernoulli(probs=ori_binary_probs,
                                                dtype=tf.float32)
    sampled_ori_binary_mask = ori_noise_dist.sample()
    sampled_ori_binary_mask = tf.cast(sampled_ori_binary_mask, tf.float32)

    # mask_binary_probs = 0.85 * (sampled_binary_mask - sampled_replace_binary_mask - sampled_ori_binary_mask) # use 80% [mask] for masked token
    # mask_noise_dist = tf.distributions.Bernoulli(probs=mask_binary_probs, dtype=tf.float32)
    # sampled_mask_binary_mask = mask_noise_dist.sample()
    # sampled_mask_binary_mask = tf.cast(sampled_mask_binary_mask, tf.float32)

    sampled_mask_binary_mask = (sampled_binary_mask -
                                sampled_replace_binary_mask -
                                sampled_ori_binary_mask)
    sampled_mask_binary_mask = tf.cast(sampled_mask_binary_mask, tf.float32)

    # sampled_replace_binary_mask *=  (1 - tf.cast(none_replace_mask, tf.float32))
    # sampled_replace_binary_mask *= tf.cast(input_mask, tf.float32)

    # sampled_mask_binary_mask *=  (1 - tf.cast(none_replace_mask, tf.float32))
    # sampled_mask_binary_mask *= tf.cast(input_mask, tf.float32)

    # sampled_ori_binary_mask *=  (1 - tf.cast(none_replace_mask, tf.float32))
    # sampled_ori_binary_mask *= tf.cast(input_mask, tf.float32)

    vocab_sample_logits = tf.random.uniform(
        [batch_size, seq_length, config.vocab_size],
        minval=0.0,
        maxval=1.0,
        dtype=tf.float32)

    vocab_sample_logits = tf.nn.log_softmax(vocab_sample_logits)
    flatten_vocab_sample_logits = tf.reshape(vocab_sample_logits,
                                             [batch_size * seq_length, -1])

    sampled_logprob_temp, sampled_logprob = gumbel_softmax(
        flatten_vocab_sample_logits,
        temperature=0.1,
        samples=config.get('gen_sample', 1))

    sample_vocab_ids = tf.argmax(sampled_logprob, axis=1)  # batch x seq

    # sample_vocab_ids = tf.multinomial(flatten_vocab_sample_logits,
    # 							num_samples=config.get('gen_sample', 1),
    # 							output_dtype=tf.int32)

    sample_vocab_ids = tf.reshape(sample_vocab_ids, [batch_size, seq_length])
    sample_vocab_ids = tf.cast(sample_vocab_ids, tf.float32)
    input_ori_ids = tf.cast(input_ori_ids, tf.float32)

    output_input_ids = mask_id * tf.cast(
        sampled_mask_binary_mask, tf.float32) * tf.ones_like(input_ori_ids)
    output_input_ids += sample_vocab_ids * tf.cast(sampled_replace_binary_mask,
                                                   tf.float32)
    output_input_ids += (
        1 - tf.cast(sampled_mask_binary_mask + sampled_replace_binary_mask,
                    tf.float32)) * input_ori_ids
    output_sampled_binary_mask = sampled_mask_binary_mask + sampled_replace_binary_mask + sampled_ori_binary_mask

    output_sampled_binary_mask = tf.cast(output_sampled_binary_mask, tf.int32)

    return [tf.cast(output_input_ids, tf.int32), output_sampled_binary_mask]
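Net effect of the cascading Bernoulli draws above: of the roughly 15% of positions sampled, about 10% are replaced with a random token, 10% of the remainder (about 9%) keep the original token, and the rest (about 81%) become `[MASK]`, close to BERT's 80/10/10 recipe:

p_replace = 0.10                  # sampled position -> random token
p_keep = (1 - p_replace) * 0.10   # sampled, not replaced -> keep original
p_mask = 1 - p_replace - p_keep   # remainder -> [MASK]
print(p_replace, p_keep, p_mask)  # 0.1 0.09 0.81 (shares of sampled positions)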
Example #13
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)
        model = model_api(model_config,
                          features,
                          labels,
                          tf.estimator.ModeKeys.PREDICT,
                          target,
                          reuse=model_reuse,
                          cnn_type=model_config.get('cnn_type', 'bi_dgcnn'),
                          **kargs)

        dropout_prob = 0.0
        is_training = False

        with tf.variable_scope(model_config.scope + "/feature_output",
                               reuse=tf.AUTO_REUSE):
            hidden_size = bert_utils.get_shape_list(model.get_pooled_output(),
                                                    expected_rank=2)[-1]
            sentence_pres = model.get_pooled_output()

            sentence_pres = tf.layers.dense(
                sentence_pres,
                128,
                use_bias=True,
                activation=tf.tanh,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.01))

            # sentence_pres = tf.layers.dense(
            # 				model.get_pooled_output(),
            # 				hidden_size,
            # 				use_bias=None,
            # 				activation=tf.nn.relu,
            # 				kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

            # sentence_pres = tf.layers.dense(
            # 				sentence_pres,
            # 				hidden_size,
            # 				use_bias=None,
            # 				activation=None,
            # 				kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

            # hidden_size = bert_utils.get_shape_list(model.get_pooled_output(), expected_rank=2)[-1]
            # sentence_pres = tf.layers.dense(
            # 			model.get_pooled_output(),
            # 			hidden_size,
            # 			use_bias=True,
            # 			activation=tf.tanh,
            # 			kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
            # feature_output_a = tf.layers.dense(
            # 				model.get_pooled_output(),
            # 				hidden_size,
            # 				kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
            # feature_output_a = tf.nn.dropout(feature_output_a, keep_prob=1 - dropout_prob)
            # feature_output_a += model.get_pooled_output()
            # sentence_pres = tf.layers.dense(
            # 				feature_output_a,
            # 				hidden_size,
            # 				kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            # 				activation=tf.tanh)

        if kargs.get('apply_head_proj', False):
            with tf.variable_scope(model_config.scope + "/head_proj",
                                   reuse=tf.AUTO_REUSE):
                sentence_pres = simclr_utils.projection_head(
                    sentence_pres,
                    is_training,
                    head_proj_dim=128,
                    num_nlh_layers=1,
                    head_proj_mode='nonlinear',
                    name='head_contrastive')

        l2_sentence_pres = tf.nn.l2_normalize(sentence_pres + 1e-20, axis=-1)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            print("==not count params==")
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions={
                'sentence_pres': l2_sentence_pres,
                # "before_l2":sentence_pres
            },
            export_outputs={
                "output":
                tf.estimator.export.PredictOutput({
                    'sentence_pres':
                    l2_sentence_pres,
                    # "before_l2":sentence_pres
                })
            })
        return estimator_spec
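Because the exported `sentence_pres` is L2-normalized, downstream similarity is a plain dot product; a client-side sketch (the names `emb_a` and `emb_b` are placeholders for two batches of exported embeddings, not part of the example above):

import numpy as np

def cosine_scores(emb_a, emb_b):
    # rows are already unit-norm, so the dot product equals cosine similarity
    return emb_a @ emb_b.T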
Example #14
    def model_fn(features, labels, mode):

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        input_shape = bert_utils.get_shape_list(input_ids, expected_rank=3)
        batch_size = input_shape[0]
        choice_num = input_shape[1]
        seq_length = input_shape[2]

        input_ids = tf.reshape(input_ids,
                               [batch_size * choice_num, seq_length])
        input_mask = tf.reshape(input_mask,
                                [batch_size * choice_num, seq_length])
        segment_ids = tf.reshape(segment_ids,
                                 [batch_size * choice_num, seq_length])

        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        model = bert.Bert(model_config)
        model.build_embedder(input_ids,
                             segment_ids,
                             hidden_dropout_prob,
                             attention_probs_dropout_prob,
                             reuse=reuse)
        model.build_encoder(input_ids,
                            input_mask,
                            hidden_dropout_prob,
                            attention_probs_dropout_prob,
                            reuse=reuse)
        model.build_pooler(reuse=reuse)

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=reuse):
            (loss, per_example_loss,
             logits) = classifier.multi_choice_classifier(
                 model_config, model.get_pooled_output(), num_labels,
                 label_ids, dropout_prob)

        # model_io_fn = model_io.ModelIO(model_io_config)
        pretrained_tvars = model_io_fn.get_params(model_config.scope)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        tvars = model_io_fn.get_params(scope,
                                       not_storage_params=not_storage_params)
        model_io_fn.set_saver(var_lst=tvars)
        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return [train_op, loss, per_example_loss, logits]

        else:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]
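The multi-choice setup flattens `[batch, choice_num, seq]` into `[batch * choice_num, seq]` before encoding; a sketch of how `classifier.multi_choice_classifier` is assumed to fold the per-choice pooled outputs back into a softmax over choices:

import tensorflow as tf

def multi_choice_sketch(pooled_output, label_ids, choice_num):
    # pooled_output: [batch * choice_num, hidden] -> one scalar score per choice
    scores = tf.layers.dense(pooled_output, 1)
    logits = tf.reshape(scores, [-1, choice_num])  # [batch, choice_num]
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=label_ids, logits=logits)
    return tf.reduce_mean(per_example_loss), per_example_loss, logits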
Example #15
    def build_output_logits(self, **kargs):
        layer_num = kargs.get("layer_num", -1)
        self.sequence_output = self.get_encoder_layers(layer_num)
        input_shape_list = bert_utils.get_shape_list(self.sequence_output,
                                                     expected_rank=3)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        hidden_dims = input_shape_list[2]

        embedding_projection = kargs.get('embedding_projection', None)

        scope = kargs.get('scope', None)
        if scope:
            scope = scope + '/' + 'cls/predictions'
        else:
            scope = 'cls/predictions'

        tf.logging.info("**** mlm generator scope **** %s", str(scope))

        # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            if self.config.get('ln_type', 'postln') == 'preln':
                input_tensor = bert_modules.layer_norm(self.sequence_output)
                tf.logging.info("**** pre ln doing layer norm ****")
            elif self.config.get('ln_type', 'postln') == 'postln':
                input_tensor = self.sequence_output
                tf.logging.info("**** post ln ****")
            else:
                input_tensor = self.sequence_output
                tf.logging.info("**** post ln ****")

            # if config.get("embedding", "factorized") == "factorized":
            # 	projection_width = config.hidden_size
            # else:
            # 	projection_width = config.embedding_size

            if self.config.get("embedding",
                               "none_factorized") == "none_factorized":
                projection_width = self.config.hidden_size
                tf.logging.info("==not using embedding factorized==")
            else:
                projection_width = self.config.get('embedding_size',
                                                   self.config.hidden_size)
                tf.logging.info(
                    "==using embedding factorized: embedding size: %s==",
                    str(projection_width))

            with tf.variable_scope("transform"):
                input_tensor = tf.layers.dense(
                    input_tensor,
                    units=projection_width,
                    activation=bert_modules.get_activation(
                        self.config.hidden_act),
                    kernel_initializer=bert_modules.create_initializer(
                        self.config.initializer_range))

                if self.config.get('ln_type', 'postln') == 'preln':
                    input_tensor = input_tensor
                    tf.logging.info("**** pre ln ****")
                elif self.config.get('ln_type', 'postln') == 'postln':
                    input_tensor = bert_modules.layer_norm(input_tensor)
                    tf.logging.info("**** post ln doing layer norm ****")
                else:
                    input_tensor = bert_modules.layer_norm(input_tensor)
                    tf.logging.info("**** post ln doing layer norm ****")

            if embedding_projection is not None:
                # batch x seq x hidden, embedding x hidden
                print(input_tensor.get_shape(),
                      embedding_projection.get_shape())
                input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                         embedding_projection)
            else:
                print("==no need for embedding projection==")
                input_tensor = input_tensor

            output_bias = tf.get_variable("output_bias",
                                          shape=[self.config.vocab_size],
                                          initializer=tf.zeros_initializer())
            # batch x seq x embedding
            logits = tf.einsum("abc,dc->abd", input_tensor,
                               self.embedding_table)
            self.logits = tf.nn.bias_add(logits, output_bias)
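The `einsum("abc,dc->abd", ...)` above ties the output logits to the embedding table: each position's hidden state is scored against every embedding row. An equivalent formulation with `tf.tensordot`, for reference:

import tensorflow as tf

def tied_logits(input_tensor, embedding_table):
    # input_tensor: [batch, seq, emb]; embedding_table: [vocab, emb]
    # contracts the embedding axis; same result as einsum("abc,dc->abd", ...)
    return tf.tensordot(input_tensor, embedding_table, axes=[[2], [1]])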
Example #16
    def model_fn(features, labels, mode):

        task_type = kargs.get("task_type", "cls")

        label_ids = features["{}_label_ids".format(task_type)]

        num_task = kargs.get('num_task', 1)

        model_io_fn = model_io.ModelIO(model_io_config)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        if kargs.get("get_pooled_output", "pooled_output") == "pooled_output":
            pooled_feature = model.get_pooled_output()
        elif kargs.get("get_pooled_output", "task_output") == "task_output":
            pooled_feature_dict = model.get_task_output()
            pooled_feature = pooled_feature_dict['pooled_feature']

        loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)],
                            tf.float32)
        loss = tf.constant(0.0)

        params_size = model_io_fn.count_params(model_config.scope)
        print("==total encoder params==", params_size)

        if kargs.get("feature_distillation", True):
            universal_feature_a = features.get("input_ids_a_features", None)
            universal_feature_b = features.get("input_ids_b_features", None)

            if universal_feature_a is None or universal_feature_b is None:
                tf.logging.info(
                    "****** not apply feature distillation *******")
                feature_loss = tf.constant(0.0)
            else:
                feature_a = pooled_feature_dict['feature_a']
                feature_a_shape = bert_utils.get_shape_list(
                    feature_a, expected_rank=[2, 3])
                pretrain_feature_a_shape = bert_utils.get_shape_list(
                    universal_feature_a, expected_rank=[2, 3])
                if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                    with tf.variable_scope(scope + "/feature_proj",
                                           reuse=tf.AUTO_REUSE):
                        proj_feature_a = tf.layers.dense(
                            feature_a, pretrain_feature_a_shape[-1])
                    # with tf.variable_scope(scope+"/feature_rec", reuse=tf.AUTO_REUSE):
                    # 	proj_feature_a_rec = tf.layers.dense(proj_feature_a, feature_a_shape[-1])
                    # loss += tf.reduce_mean(tf.reduce_sum(tf.square(proj_feature_a_rec-feature_a), axis=-1))/float(num_task)
                    tf.logging.info(
                        "****** apply auto-encoder for feature compression *******"
                    )
                else:
                    proj_feature_a = feature_a
                feature_a_norm = tf.stop_gradient(
                    tf.sqrt(
                        tf.reduce_sum(tf.pow(proj_feature_a, 2),
                                      axis=-1,
                                      keepdims=True)) + 1e-20)
                proj_feature_a /= feature_a_norm

                feature_b = pooled_feature_dict['feature_b']
                if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                    with tf.variable_scope(scope + "/feature_proj",
                                           reuse=tf.AUTO_REUSE):
                        proj_feature_b = tf.layers.dense(
                            feature_b, pretrain_feature_a_shape[-1])
                    # with tf.variable_scope(scope+"/feature_rec", reuse=tf.AUTO_REUSE):
                    # 	proj_feature_b_rec = tf.layers.dense(proj_feature_b, feature_a_shape[-1])
                    # loss += tf.reduce_mean(tf.reduce_sum(tf.square(proj_feature_b_rec-feature_b), axis=-1))/float(num_task)
                    tf.logging.info(
                        "****** apply auto-encoder for feature compression *******"
                    )
                else:
                    proj_feature_b = feature_b

                feature_b_norm = tf.stop_gradient(
                    tf.sqrt(
                        tf.reduce_sum(tf.pow(proj_feature_b, 2),
                                      axis=-1,
                                      keepdims=True)) + 1e-20)
                proj_feature_b /= feature_b_norm

                feature_a_distillation = tf.reduce_mean(
                    tf.square(universal_feature_a - proj_feature_a), axis=-1)
                feature_b_distillation = tf.reduce_mean(
                    tf.square(universal_feature_b - proj_feature_b), axis=-1)

                feature_loss = tf.reduce_mean(
                    (feature_a_distillation + feature_b_distillation) /
                    2.0) / float(num_task)
                loss += feature_loss
                tf.logging.info(
                    "****** apply prertained feature distillation *******")

        if kargs.get("embedding_distillation", True):
            word_embed = model.emb_mat
            random_embed_shape = bert_utils.get_shape_list(
                word_embed, expected_rank=[2, 3])
            print("==random_embed_shape==", random_embed_shape)
            pretrained_embed = kargs.get('pretrained_embed', None)
            if pretrained_embed is None:
                tf.logging.info(
                    "****** not apply prertained feature distillation *******")
                embed_loss = tf.constant(0.0)
            else:
                pretrain_embed_shape = bert_utils.get_shape_list(
                    pretrained_embed, expected_rank=[2, 3])
                print("==pretrain_embed_shape==", pretrain_embed_shape)
                if random_embed_shape[-1] != pretrain_embed_shape[-1]:
                    with tf.variable_scope(scope + "/embedding_proj",
                                           reuse=tf.AUTO_REUSE):
                        proj_embed = tf.layers.dense(word_embed,
                                                     pretrain_embed_shape[-1])
                else:
                    proj_embed = word_embed

                embed_loss = tf.reduce_mean(
                    tf.reduce_mean(tf.square(proj_embed - pretrained_embed),
                                   axis=-1)) / float(num_task)
                loss += embed_loss
                tf.logging.info(
                    "****** apply prertained feature distillation *******")

        with tf.variable_scope(scope + "/{}/classifier".format(task_type),
                               reuse=task_layer_reuse):
            (_, per_example_loss,
             logits) = classifier.classifier(model_config, pooled_feature,
                                             num_labels, label_ids,
                                             dropout_prob)

        loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)],
                            tf.float32)
        masked_per_example_loss = per_example_loss * loss_mask
        task_loss = tf.reduce_sum(masked_per_example_loss) / (
            1e-10 + tf.reduce_sum(loss_mask))
        loss += task_loss

        if mode == tf.estimator.ModeKeys.TRAIN:
            multi_task_config = kargs.get("multi_task_config", {})
            if multi_task_config[task_type].get("lm_augumentation", False):
                print("==apply lm_augumentation==")
                masked_lm_positions = features["masked_lm_positions"]
                masked_lm_ids = features["masked_lm_ids"]
                masked_lm_weights = features["masked_lm_weights"]
                (masked_lm_loss, masked_lm_example_loss,
                 masked_lm_log_probs) = pretrain.get_masked_lm_output(
                     model_config,
                     model.get_sequence_output(),
                     model.get_embedding_table(),
                     masked_lm_positions,
                     masked_lm_ids,
                     masked_lm_weights,
                     reuse=model_reuse)

                masked_lm_loss_mask = tf.expand_dims(loss_mask, -1) * tf.ones(
                    (1,
                     multi_task_config[task_type]["max_predictions_per_seq"]))
                masked_lm_loss_mask = tf.reshape(masked_lm_loss_mask, (-1, ))

                masked_lm_label_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_loss_mask *= tf.cast(masked_lm_label_weights,
                                               tf.float32)

                masked_lm_example_loss *= masked_lm_loss_mask  # multiply task_mask
                masked_lm_loss = tf.reduce_sum(masked_lm_example_loss) / (
                    1e-10 + tf.reduce_sum(masked_lm_loss_mask))
                loss += multi_task_config[task_type][
                    "masked_lm_loss_ratio"] * masked_lm_loss

                masked_lm_label_ids = tf.reshape(masked_lm_ids, [-1])

                print(masked_lm_log_probs.get_shape(),
                      "===masked lm log probs===")
                print(masked_lm_label_ids.get_shape(), "===masked lm ids===")
                print(masked_lm_label_weights.get_shape(),
                      "===masked lm mask===")

                lm_acc = build_accuracy(masked_lm_log_probs,
                                        masked_lm_label_ids,
                                        masked_lm_loss_mask)

        if kargs.get("task_invariant", "no") == "yes":
            print("==apply task adversarial training==")
            with tf.variable_scope(scope + "/dann_task_invariant",
                                   reuse=model_reuse):
                (_, task_example_loss,
                 task_logits) = distillation_utils.feature_distillation(
                     model.get_pooled_output(), 1.0, features["task_id"],
                     kargs.get("num_task", 7), dropout_prob, True)
                masked_task_example_loss = loss_mask * task_example_loss
                masked_task_loss = tf.reduce_sum(masked_task_example_loss) / (
                    1e-10 + tf.reduce_sum(loss_mask))
                loss += kargs.get("task_adversarial", 1e-2) * masked_task_loss

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        if mode == tf.estimator.ModeKeys.TRAIN:
            multi_task_config = kargs.get("multi_task_config", {})
            if multi_task_config[task_type].get("lm_augumentation", False):
                print("==apply lm_augumentation==")
                masked_lm_pretrain_tvars = model_io_fn.get_params(
                    "cls/predictions", not_storage_params=not_storage_params)
                tvars.extend(masked_lm_pretrain_tvars)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            print("==not count params==")
        # print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            acc = build_accuracy(logits, label_ids, loss_mask)

            return_dict = {
                "loss": loss,
                "logits": logits,
                "task_num": tf.reduce_sum(loss_mask),
                "tvars": tvars
            }
            return_dict["{}_acc".format(task_type)] = acc
            if kargs.get("task_invariant", "no") == "yes":
                return_dict["{}_task_loss".format(
                    task_type)] = masked_task_loss
                task_acc = build_accuracy(task_logits, features["task_id"],
                                          loss_mask)
                return_dict["{}_task_acc".format(task_type)] = task_acc
            if multi_task_config[task_type].get("lm_augumentation", False):
                return_dict["{}_masked_lm_loss".format(
                    task_type)] = masked_lm_loss
                return_dict["{}_masked_lm_acc".format(task_type)] = lm_acc
            if kargs.get("embedding_distillation", True):
                return_dict["embed_loss"] = embed_loss * float(num_task)
            else:
                return_dict["embed_loss"] = task_loss
            if kargs.get("feature_distillation", True):
                return_dict["feature_loss"] = feature_loss * float(num_task)
            else:
                return_dict["feature_loss"] = task_loss
            return_dict["task_loss"] = task_loss
            return return_dict
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_dict = {
                "loss": loss,
                "logits": logits,
                "feature": model.get_pooled_output()
            }
            if kargs.get("adversarial", "no") == "adversarial":
                eval_dict["task_logits"] = task_logits
            return eval_dict
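
The masked-mean pattern used above, sum(loss * mask) / (1e-10 + sum(mask)), recurs throughout these examples. A minimal standalone NumPy sketch (not part of the original repo) showing why the 1e-10 guard matters when no example in a batch belongs to the current task:

import numpy as np

def masked_mean(per_example_loss, loss_mask, eps=1e-10):
    # Average only over active examples; the eps guard keeps the division
    # finite when the mask is all zeros.
    per_example_loss = np.asarray(per_example_loss, dtype=np.float32)
    loss_mask = np.asarray(loss_mask, dtype=np.float32)
    return float(np.sum(per_example_loss * loss_mask) / (eps + np.sum(loss_mask)))

print(masked_mean([0.5, 2.0, 1.0], [1, 0, 1]))  # 0.75
print(masked_mean([0.5, 2.0, 1.0], [0, 0, 0]))  # 0.0 rather than NaN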
Ejemplo n.º 17
0
def discriminator_metric_eval(input_dict, **kargs):

    output_dict = {}
    d_out_real = input_dict['true_logits']
    d_out_fake = input_dict['fake_logits']

    input_shape_list = bert_utils.get_shape_list(d_out_real, expected_rank=[2])
    batch_size = input_shape_list[0]

    true_labels = tf.cast(tf.ones(batch_size), tf.int32)
    fake_labels = tf.cast(tf.zeros(batch_size), tf.int32)

    pred_true_label = tf.cast(tf.argmax(d_out_real, axis=-1), tf.int32)
    pred_fake_label = tf.cast(tf.argmax(d_out_fake, axis=-1), tf.int32)

    all_pred_label = tf.concat([pred_true_label, pred_fake_label], axis=0)
    all_true_label = tf.concat([true_labels, fake_labels], axis=0)

    if not kargs.get('use_tpu', True):
        discriminator_f1 = tf_metrics.f1(all_true_label,
                                         all_pred_label,
                                         2,
                                         average="macro")
        discriminator_precision = tf_metrics.precision(all_true_label,
                                                       all_pred_label,
                                                       2,
                                                       average="macro")
        discriminator_recall = tf_metrics.recall(all_true_label,
                                                 all_pred_label,
                                                 2,
                                                 average="macro")
        discriminator_f1_original = tf_metrics.f1(all_true_label,
                                                  all_pred_label,
                                                  2,
                                                  pos_indices=[0],
                                                  average="macro")
        discriminator_f1_replaced = tf_metrics.f1(all_true_label,
                                                  all_pred_label,
                                                  2,
                                                  pos_indices=[1],
                                                  average="macro")
        discriminator_precision_original = tf_metrics.precision(
            all_true_label,
            all_pred_label,
            2,
            pos_indices=[0],
            average="macro")
        discriminator_precision_replaced = tf_metrics.precision(
            all_true_label,
            all_pred_label,
            2,
            pos_indices=[1],
            average="macro")
        discriminator_recall_original = tf_metrics.recall(all_true_label,
                                                          all_pred_label,
                                                          2,
                                                          pos_indices=[0],
                                                          average="macro")
        discriminator_recall_replaced = tf_metrics.recall(all_true_label,
                                                          all_pred_label,
                                                          2,
                                                          pos_indices=[1],
                                                          average="macro")
        output_dict['discriminator_f1'] = discriminator_f1
        output_dict['discriminator_precision'] = discriminator_precision
        output_dict['discriminator_recall'] = discriminator_recall
        output_dict['discriminator_f1_original'] = discriminator_f1_original
        output_dict['discriminator_f1_replaced'] = discriminator_f1_replaced
        output_dict[
            'discriminator_precision_original'] = discriminator_precision_original
        output_dict[
            'discriminator_precision_replaced'] = discriminator_precision_replaced
        output_dict[
            'discriminator_recall_original'] = discriminator_recall_original
        output_dict[
            'discriminator_recall_replaced'] = discriminator_recall_replaced
    else:
        discriminator_recall = tf.compat.v1.metrics.recall(
            tf.one_hot(all_true_label, 2), tf.one_hot(all_pred_label, 2))

        discriminator_precision = tf.compat.v1.metrics.precision(
            tf.one_hot(all_true_label, 2), tf.one_hot(all_pred_label, 2))
        discriminator_f1 = tf_metrics.f1(all_true_label,
                                         all_pred_label,
                                         2,
                                         average="macro")
        discriminator_f1_original = tf_metrics.f1(all_true_label,
                                                  all_pred_label,
                                                  2,
                                                  pos_indices=[0],
                                                  average="macro")
        discriminator_f1_replaced = tf_metrics.f1(all_true_label,
                                                  all_pred_label,
                                                  2,
                                                  pos_indices=[1],
                                                  average="macro")
        discriminator_precision_original = tf_metrics.precision(
            all_true_label,
            all_pred_label,
            2,
            pos_indices=[0],
            average="macro")
        discriminator_precision_replaced = tf_metrics.precision(
            all_true_label,
            all_pred_label,
            2,
            pos_indices=[1],
            average="macro")
        discriminator_recall_original = tf_metrics.recall(all_true_label,
                                                          all_pred_label,
                                                          2,
                                                          pos_indices=[0],
                                                          average="macro")
        discriminator_recall_replaced = tf_metrics.recall(all_true_label,
                                                          all_pred_label,
                                                          2,
                                                          pos_indices=[1],
                                                          average="macro")

        output_dict['discriminator_f1_original'] = discriminator_f1_original
        output_dict['discriminator_f1_replaced'] = discriminator_f1_replaced
        output_dict[
            'discriminator_precision_original'] = discriminator_precision_original
        output_dict[
            'discriminator_precision_replaced'] = discriminator_precision_replaced
        output_dict[
            'discriminator_recall_original'] = discriminator_recall_original
        output_dict[
            'discriminator_recall_replaced'] = discriminator_recall_replaced
        output_dict['discriminator_f1'] = discriminator_f1
        output_dict['discriminator_precision'] = discriminator_precision
        output_dict['discriminator_recall'] = discriminator_recall
    return output_dict
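
For intuition, here is a small self-contained NumPy sketch of the macro-averaged precision/recall/F1 computed above for the two-class real-vs-fake setup (the tf_metrics versions are streaming TF ops; this only mirrors the arithmetic):

import numpy as np

def macro_prf1(y_true, y_pred, num_classes=2):
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * p * r / (p + r) if p + r else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f1)
    # Macro average: unweighted mean over classes.
    return np.mean(precisions), np.mean(recalls), np.mean(f1s)

y_true = np.array([1, 1, 0, 0])  # 1 = original, 0 = replaced
y_pred = np.array([1, 0, 0, 0])
print(macro_prf1(y_true, y_pred))  # (~0.833, 0.75, ~0.733)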
Ejemplo n.º 18
0
def sample_sequence_without_cache(model_api,
                                  model_config,
                                  mode,
                                  features,
                                  target="",
                                  start_token=101,
                                  batch_size=None,
                                  seq_length=None,
                                  context=None,
                                  temperature=1,
                                  n_samples=1,
                                  top_k=0,
                                  end_token=102,
                                  greedy_or_sample="sample",
                                  gumbel_temp=0.01,
                                  estimator="straight_through",
                                  back_prop=True,
                                  swap_memory=True,
                                  max_seq_length=512,
                                  **kargs):

    input_shape = bert_utils.get_shape_list(features["input_ids"],
                                            expected_rank=[2, 3])
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    actual_length = seq_length

    if context is None:
        assert start_token is not None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)
        context = tf.cast(context, tf.int32)
        context_shape = bert_utils.get_shape_list(context, expected_rank=[2])
        print(context.get_shape(), "===init context shape===")
    else:
        context = tf.cast(context, tf.int32)
        context_shape = bert_utils.get_shape_list(context, expected_rank=[2])
        batch_size = input_shape[0]

    samples = tf.cast(tf.zeros((batch_size, actual_length)), tf.int32)
    end_mask = tf.expand_dims(tf.one_hot(actual_length - 1, actual_length),
                              axis=(0))
    samples += end_token * tf.cast(
        end_mask, tf.int32)  # make sure last token is end token

    start_mask = tf.one_hot(tf.range(0, context_shape[1]), actual_length)
    samples += tf.cast(
        tf.einsum("ab,bc->ac", tf.cast(context, tf.float32),
                  tf.cast(start_mask, tf.float32)), tf.int32)

    segment_ids = tf.cast(
        tf.zeros((batch_size, actual_length - context_shape[1])), tf.int32)

    if kargs.get("mask_type", "left2right") == 'left2right':
        segment_ids = tf.concat([
            tf.cast(tf.zeros(
                (batch_size, context_shape[1])), tf.int32), segment_ids
        ],
                                axis=-1)
    elif kargs.get("mask_type", "left2right") == 'seq2seq':
        segment_ids = tf.concat([
            tf.cast(tf.ones(
                (batch_size, context_shape[1])), tf.int32), segment_ids
        ],
                                axis=-1)

    logits = tf.cast(tf.zeros((batch_size, actual_length)), tf.float32)

    input_mask = tf.cast(
        tf.zeros((batch_size, actual_length - context_shape[1])), tf.int32)
    input_mask = tf.concat([
        tf.cast(tf.ones((batch_size, context_shape[1])), tf.int32), input_mask
    ],
                           axis=-1)

    if estimator in ["straight_through", "soft"]:
        gumbel_probs = tf.zeros((batch_size, actual_length - context_shape[1],
                                 model_config.vocab_size))

        start_probs = context
        start_one_hot = tf.one_hot(start_probs, model_config.vocab_size)
        gumbel_probs = tf.concat(
            [tf.cast(start_one_hot, tf.float32), gumbel_probs], axis=1)

    def step(step, tokens, input_mask, segment_ids):

        token_shape = bert_utils.get_shape_list(tokens, expected_rank=[2, 3])

        features = {}
        features['input_ids'] = tokens
        features['segment_ids'] = segment_ids
        features['input_mask'] = input_mask

        inference_model = model_api(model_config,
                                    features, [],
                                    mode,
                                    target,
                                    reuse=tf.AUTO_REUSE,
                                    **kargs)

        logits = inference_model.get_sequence_output_logits()

        return {'logits': logits}

    with tf.name_scope('sample_sequence'):

        def get_samples_logits(samples, logits):
            batch_idxs = tf.range(0, tf.shape(samples)[0])
            batch_idxs = tf.expand_dims(tf.cast(batch_idxs, tf.int32), 1)
            samples = tf.expand_dims(tf.cast(samples, tf.int32), 1)

            idxs = tf.concat([batch_idxs, samples], 1)
            sample_logits = tf.gather_nd(logits, idxs)
            return sample_logits

        def body(i, samples, input_mask, segment_ids, logits):
            next_outputs = step(i, samples, input_mask, segment_ids)

            logits_mask = tf.expand_dims(tf.one_hot(i - 1, actual_length),
                                         axis=(0))  # [1, seq]

            next_logits = tf.reduce_sum(
                next_outputs['logits'] *
                tf.cast(tf.expand_dims(logits_mask, axis=-1), tf.float32),
                axis=1)

            next_logits = next_logits / tf.to_float(temperature)

            next_logits = tf.nn.log_softmax(next_logits, axis=-1)
            if greedy_or_sample == "sample":
                next_samples = tf.multinomial(next_logits,
                                              num_samples=1,
                                              output_dtype=tf.int32)
                next_samples = tf.squeeze(next_samples, axis=-1)
            elif greedy_or_sample == "greedy":
                next_samples = tf.argmax(next_logits, axis=-1)
            else:
                next_samples = tf.argmax(next_logits, axis=-1)
            next_samples = tf.cast(next_samples, tf.int32)
            print(next_samples.get_shape(), "==sample shape==")

            print(tf.one_hot(i, actual_length).get_shape(), "===one-hot shape===")

            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))  # [1, seq]

            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            print(next_sample_logits.get_shape(),
                  "===next sample logits shape===")
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)

            input_mask += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(tf.ones_like(next_samples), axis=-1), tf.int32)

            return [i + 1, samples, input_mask, segment_ids, logits]

        def gumbel_st_body(i, samples, gumbel_probs, input_mask, segment_ids,
                           logits):

            next_outputs = step(i, gumbel_probs, input_mask, segment_ids)

            # next_logits = next_outputs['logits'][:, i-1, :]  / tf.to_float(temperature)
            logits_mask = tf.expand_dims(tf.one_hot(i - 1, actual_length),
                                         axis=(0))  # [1, seq]
            next_logits = tf.reduce_sum(
                next_outputs['logits'] *
                tf.cast(tf.expand_dims(logits_mask, axis=-1), tf.float32),
                axis=1)
            next_logits = next_logits / tf.to_float(temperature)
            next_logits = tf.nn.log_softmax(next_logits, axis=-1)

            next_gumbel_probs, _ = gumbel_softmax(next_logits,
                                                  gumbel_temp,
                                                  gumbel_samples=None,
                                                  samples=1)
            next_samples = tf.cast(tf.argmax(next_gumbel_probs, axis=1),
                                   tf.int32)
            next_samples_onehot = tf.one_hot(next_samples,
                                             model_config.vocab_size,
                                             axis=1)  # one-hot of the sampled multinomial id
            straight_through_onehot = tf.stop_gradient(
                next_samples_onehot - next_gumbel_probs) + next_gumbel_probs

            print(next_gumbel_probs.get_shape(), "=====gumbel====",
                  straight_through_onehot.get_shape())
            gumbel_mask = tf.expand_dims(tf.expand_dims(tf.one_hot(
                i, actual_length),
                                                        axis=0),
                                         axis=2)  # [1, seq, 1]
            gumbel_probs += tf.cast(gumbel_mask, tf.float32) * tf.expand_dims(
                straight_through_onehot, axis=1)  # b x 1 x vocab

            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)
            input_mask += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(tf.ones_like(next_samples), axis=-1), tf.int32)

            return [
                i + 1, samples, gumbel_probs, input_mask, segment_ids, logits
            ]

        def gumbel_soft_body(i, samples, gumbel_probs, input_mask, segment_ids,
                             logits):
            next_outputs = step(i, samples, input_mask, segment_ids)

            logits_mask = tf.expand_dims(tf.one_hot(i - 1, actual_length),
                                         axis=(0))  # [1, seq]

            next_logits = tf.reduce_sum(
                next_outputs['logits'] *
                tf.cast(tf.expand_dims(logits_mask, axis=-1), tf.float32),
                axis=1)

            next_logits = next_logits / tf.to_float(temperature)

            # gumbel sample
            next_gumbel_probs, _ = gumbel_softmax(next_logits,
                                                  gumbel_temp,
                                                  gumbel_samples=None,
                                                  samples=1)
            next_samples = tf.cast(tf.argmax(next_gumbel_probs, axis=1),
                                   tf.int32)

            print(next_gumbel_probs.get_shape())
            gumbel_mask = tf.expand_dims(tf.expand_dims(tf.one_hot(
                i, actual_length),
                                                        axis=0),
                                         axis=2)  # [1, seq, 1]
            gumbel_probs += tf.cast(gumbel_mask, tf.float32) * tf.expand_dims(
                next_gumbel_probs, axis=1)  # b x 1 x vocab

            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)

            return [
                i + 1, samples, gumbel_probs, input_mask, segment_ids, logits
            ]

        init_i = tf.cast(
            bert_utils.get_shape_list(context, expected_rank=[2, 3])[1],
            tf.int32)

        if estimator == "straight_through":
            # final, samples, gumbel_probs, input_mask, segment_ids, logits = tf.while_loop(
            # 	cond=lambda i, _1, _2, _3, _4, _5: i < seq_length-1,
            # 	body=gumbel_st_body,
            # 	loop_vars=[init_i,
            # 		samples,
            # 		gumbel_probs,
            # 		input_mask,
            # 		segment_ids,
            # 		logits
            # 	],
            # 	back_prop=back_prop,
            # 	swap_memory=swap_memory,
            # 	maximum_iterations=seq_length
            # )

            for i in range(1, max_seq_length - 1):
                [
                    final, samples, gumbel_probs, input_mask, segment_ids,
                    logits
                ] = gumbel_st_body(i, samples, gumbel_probs, input_mask,
                                   segment_ids, logits)

        elif estimator == "soft":
            # final, samples, gumbel_probs, input_mask, segment_ids, logits = tf.while_loop(
            # 	cond=lambda i, _1, _2, _3, _4, _5: i < seq_length-1,
            # 	body=gumbel_soft_body,
            # 	loop_vars=[init_i,
            # 		samples,
            # 		gumbel_probs,
            # 		input_mask,
            # 		segment_ids,
            # 		logits
            # 	],
            # 	back_prop=back_prop,
            # 	swap_memory=swap_memory,
            # 	maximum_iterations=seq_length
            # )

            for i in range(1, max_seq_length - 1):
                [
                    final, samples, gumbel_probs, input_mask, segment_ids,
                    logits
                ] = gumbel_soft_body(i, samples, gumbel_probs, input_mask,
                                     segment_ids, logits)

        else:
            # final, samples, input_mask, segment_ids, logits = tf.while_loop(
            # 	cond=lambda i, _1, _2, _3, _4: i < seq_length-1,
            # 	body=body,
            # 	loop_vars=[init_i,
            # 		samples,
            # 		input_mask,
            # 		segment_ids,
            # 		logits
            # 	],
            # 	back_prop=back_prop,
            # 	swap_memory=swap_memory,
            # 	maximum_iterations=seq_length
            # )

            for i in range(1, max_seq_length - 1):
                [final, samples, input_mask, segment_ids,
                 logits] = body(i, samples, input_mask, segment_ids, logits)

        mask_sequence = get_finised_pos_v1(samples, end_token, actual_length)
        print(mask_sequence.get_shape(), "==mask shape==")
        samples *= tf.cast(mask_sequence, tf.int32)
        logits *= tf.cast(mask_sequence, tf.float32)
        if estimator in ["straight_through", "soft"]:
            gumbel_probs *= tf.expand_dims(tf.cast(mask_sequence, tf.float32),
                                           axis=-1)
            return {
                "samples": samples,
                "mask_sequence": mask_sequence,
                "gumbel_probs": gumbel_probs,
                "logits": logits,
                "input_mask": input_mask
            }
        else:
            return {
                "samples": samples,
                "mask_sequence": mask_sequence,
                "logits": logits,
                "input_mask": input_mask
            }
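
The straight_through branch above hinges on one line: tf.stop_gradient(onehot - probs) + probs, which evaluates to the hard one-hot in the forward pass while gradients flow through the soft Gumbel-softmax probabilities. A standalone NumPy sketch of the forward computation (gradients, of course, require a framework):

import numpy as np

def gumbel_softmax_st(logits, temperature=0.5, seed=0):
    rng = np.random.default_rng(seed)
    # Gumbel(0, 1) noise via the inverse-CDF trick.
    gumbel = -np.log(-np.log(rng.uniform(1e-9, 1.0, size=logits.shape)))
    y = (logits + gumbel) / temperature
    y -= y.max(axis=-1, keepdims=True)                 # stabilized softmax
    soft = np.exp(y) / np.exp(y).sum(axis=-1, keepdims=True)
    hard = np.eye(logits.shape[-1])[soft.argmax(axis=-1)]
    # Straight-through: forward = hard, backward = soft
    # (in TF: tf.stop_gradient(hard - soft) + soft).
    return hard, soft

hard, soft = gumbel_softmax_st(np.log(np.array([[0.7, 0.2, 0.1]])))
print(hard, soft.round(3))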
Ejemplo n.º 19
0
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE,
             scope='generator')

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            print("==apply bert masked lm==")

        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table(),
             scope='generator')
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss  #+ model_config.nsp_ratio * nsp_loss

        sampled_ids = token_generator(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            features['input_ids'],
            features['input_ori_ids'],
            features['input_mask'],
            embedding_projection=model.get_embedding_projection_table(),
            scope='generator',
            mask_method='only_mask')

        if model_config.get('gen_sample', 1) == 1:
            input_ids = features['input_ori_ids']
            input_mask = features['input_mask']
            segment_ids = features['segment_ids']
        else:
            input_ids = tf.expand_dims(features['input_ori_ids'], axis=-1)
            # batch x seq_length x 1
            input_ids = tf.einsum(
                'abc,cd->abd', input_ids,
                tf.ones((1, model_config.get('gen_sample', 1))))
            input_ids = tf.cast(input_ids, tf.int32)

            input_shape_list = bert_utils.get_shape_list(input_ids,
                                                         expected_rank=3)
            batch_size = input_shape_list[0]
            seq_length = input_shape_list[1]
            gen_sample = input_shape_list[2]

            sampled_ids = tf.reshape(sampled_ids,
                                     [batch_size * gen_sample, seq_length])
            input_ids = tf.reshape(input_ids,
                                   [batch_size * gen_sample, seq_length])

            input_mask = tf.expand_dims(features['input_mask'], axis=-1)
            input_mask = tf.einsum(
                'abc,cd->abd', input_mask,
                tf.ones((1, model_config.get('gen_sample', 1))))
            input_mask = tf.cast(input_mask, tf.int32)

            segment_ids = tf.expand_dims(features['segment_ids'], axis=-1)
            segment_ids = tf.einsum(
                'abc,cd->abd', segment_ids,
                tf.ones((1, model_config.get('gen_sample', 1))))
            segment_ids = tf.cast(segment_ids, tf.int32)

            segment_ids = tf.reshape(segment_ids,
                                     [batch_size * gen_sample, seq_length])
            input_mask = tf.reshape(input_mask,
                                    [batch_size * gen_sample, seq_length])

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "generator/cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope="generator",
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "sampled_ids": sampled_ids,  # batch x gen_sample, seg_length
            "sampled_input_ids": input_ids,  # batch x gen_sample, seg_length,
            "sampled_input_mask": input_mask,
            "sampled_segment_ids": segment_ids,
            "masked_lm_positions": masked_lm_positions,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_weights": masked_lm_weights,
            "masked_lm_log_probs": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss
        }
        return return_dict
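
When gen_sample > 1, the branch above tiles each [batch, seq] feature into [batch * gen_sample, seq] with an einsum against a ones matrix. A standalone NumPy sketch of that tiling trick; note that flattening via a transpose to [batch, gen_sample, seq] keeps each row an intact sequence:

import numpy as np

batch_size, seq_length, gen_sample = 2, 3, 4
input_ids = np.arange(batch_size * seq_length).reshape(batch_size, seq_length)

# [batch, seq, 1] x [1, gen_sample] -> [batch, seq, gen_sample]
tiled = np.einsum('abc,cd->abd',
                  input_ids[:, :, None].astype(np.float32),
                  np.ones((1, gen_sample), dtype=np.float32))
flat = tiled.transpose(0, 2, 1).reshape(batch_size * gen_sample, seq_length)
print(flat.astype(np.int32))  # each original sequence repeated gen_sample times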
Ejemplo n.º 20
0
def seq_mask_masked_lm_output(config, input_tensor, output_weights,
							input_mask, input_ori_ids, input_ids, 
							sampled_binary_mask, **kargs):

	input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
	batch_size = input_shape_list[0]
	seq_length = input_shape_list[1]
	hidden_dims = input_shape_list[2]

	embedding_projection = kargs.get('embedding_projection', None)

	scope = kargs.get('scope', None)
	if scope:
		scope = scope + '/' + 'cls/predictions'
	else:
		scope = 'cls/predictions'

	tf.logging.info("**** mlm generator scope **** %s", str(scope))

	# with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
	with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
		if config.get('ln_type', 'postln') == 'preln':
			input_tensor = albert_modules.layer_norm(input_tensor)
		elif config.get('ln_type', 'postln') == 'postln':
			input_tensor = input_tensor
		else:
			input_tensor = input_tensor

		if config.get("embedding", "factorized") == "factorized":
			projection_width = config.hidden_size
		else:
			projection_width = config.embedding_size

		with tf.variable_scope("transform"):
			input_tensor = tf.layers.dense(
					input_tensor,
					units=projection_width,
					activation=albert_modules.get_activation(config.hidden_act),
					kernel_initializer=albert_modules.create_initializer(
							config.initializer_range))

			if config.get('ln_type', 'postln') == 'preln':
				input_tensor = input_tensor
			elif config.get('ln_type', 'postln') == 'postln':
				input_tensor = albert_modules.layer_norm(input_tensor)
			else:
				input_tensor = albert_modules.layer_norm(input_tensor)

		if embedding_projection is not None:
			# batch x seq x hidden, embedding x hidden
			print(input_tensor.get_shape(), embedding_projection.get_shape())
			input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection)
		else:
			print("==no need for embedding projection==")
			input_tensor = input_tensor

		output_bias = tf.get_variable(
				"output_bias",
				shape=[config.vocab_size],
				initializer=tf.zeros_initializer())
		# batch x seq x embedding
		logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
		logits = tf.nn.bias_add(logits, output_bias)

		"""
		if input_ori_ids[i] is randomly perturbed, sampled_binary_mask[i] = 1
		"""
		sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32)
		per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
												logits=logits,
												labels=tf.stop_gradient(input_ori_ids),
												)
		per_example_loss *= sampled_binary_mask
		loss = tf.reduce_sum(per_example_loss) / (1e-10 + tf.reduce_sum(sampled_binary_mask))

		return (loss, per_example_loss, logits, sampled_binary_mask)
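
A minimal NumPy sketch of the loss computed above: per-token softmax cross-entropy against the original ids, zeroed at unperturbed positions and normalized by the number of perturbed tokens:

import numpy as np

def seq_masked_lm_loss(logits, labels, binary_mask, eps=1e-10):
    # logits: [batch, seq, vocab]; labels, binary_mask: [batch, seq]
    log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
    per_token = -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    per_token = per_token * binary_mask  # only perturbed positions contribute
    return np.sum(per_token) / (eps + np.sum(binary_mask))

logits = np.random.randn(2, 4, 8)
labels = np.random.randint(0, 8, size=(2, 4))
mask = np.array([[1, 0, 1, 0], [0, 1, 0, 0]], dtype=np.float32)
print(seq_masked_lm_loss(logits, labels, mask))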
Ejemplo n.º 21
0
def iso_gaussian_sample(logits, temperature, samples=1):
    input_shape_list = bert_utils.get_shape_list(logits, expected_rank=2)
    if samples > 1:
        logits = tf.expand_dims(logits, -1)
    y = logits + sample_normal(input_shape_list, samples)
    return [tf.exp(tf.nn.log_softmax(y / temperature)), logits]
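
A standalone NumPy sketch of the perturb-and-softmax sampling above (for samples=1): add isotropic Gaussian noise to the logits, then take a temperature-scaled softmax. Lower temperatures sharpen the output toward the argmax of the noisy logits:

import numpy as np

def iso_gaussian_sample_np(logits, temperature, seed=0):
    rng = np.random.default_rng(seed)
    y = (logits + rng.standard_normal(logits.shape)) / temperature
    y -= y.max(axis=-1, keepdims=True)  # stabilized softmax
    return np.exp(y) / np.exp(y).sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 1.0, 0.5]])
print(iso_gaussian_sample_np(logits, temperature=1.0).round(3))
print(iso_gaussian_sample_np(logits, temperature=0.1).round(3))  # near one-hot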
Ejemplo n.º 22
0
def distributed_transformer_model(input_tensor,
                                  attention_mask=None,
                                  hidden_size=768,
                                  num_hidden_layers=12,
                                  num_attention_heads=12,
                                  intermediate_size=3072,
                                  intermediate_act_fn=gelu,
                                  hidden_dropout_prob=0.1,
                                  attention_probs_dropout_prob=0.1,
                                  initializer_range=0.02,
                                  do_return_all_layers=False,
                                  gpu_nums=2):
    """Multi-headed, multi-layer Transformer from "Attention is All You Need".

	This is almost an exact implementation of the original Transformer encoder.

	See the original paper:
	https://arxiv.org/abs/1706.03762

	Also see:
	https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

	Args:
		input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
		attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
			seq_length], with 1 for positions that can be attended to and 0 in
			positions that should not be.
		hidden_size: int. Hidden size of the Transformer.
		num_hidden_layers: int. Number of layers (blocks) in the Transformer.
		num_attention_heads: int. Number of attention heads in the Transformer.
		intermediate_size: int. The size of the "intermediate" (a.k.a., feed
			forward) layer.
		intermediate_act_fn: function. The non-linear activation function to apply
			to the output of the intermediate/feed-forward layer.
		hidden_dropout_prob: float. Dropout probability for the hidden layers.
		attention_probs_dropout_prob: float. Dropout probability of the attention
			probabilities.
		initializer_range: float. Range of the initializer (stddev of truncated
			normal).
		do_return_all_layers: Whether to also return all layers or just the final
			layer.
		gpu_nums: int. Number of GPUs to partition the hidden layers across.

	Returns:
		float Tensor of shape [batch_size, seq_length, hidden_size], the final
		hidden layer of the Transformer.

	Raises:
		ValueError: A Tensor shape or parameter is invalid.
	"""
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError(
            "The width of the input tensor (%d) != hidden size (%d)" %
            (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = bert_utils.reshape_to_matrix(input_tensor)

    all_layer_outputs = []

    gpu_partition = int(num_hidden_layers / gpu_nums)

    gpu_id = -1  # gpu_id ranges from 0 to gpu_nums - 1

    for layer_idx in range(num_hidden_layers):
        with tf.variable_scope("layer_%d" % layer_idx):
            layer_input = prev_output

            if np.mod(layer_idx, gpu_partition) == 0:
                gpu_id += 1

            with tf.device('/gpu:{}'.format(gpu_id)):

                tf.logging.info(
                    " apply transformer attention {}-th layer on device {} ".
                    format(layer_idx, gpu_id))
                print(" apply transformer attention {}-th layer on device {} ".
                      format(layer_idx, gpu_id))

                with tf.variable_scope("attention"):
                    attention_heads = []
                    with tf.variable_scope("self"):
                        attention_head = attention_layer(
                            from_tensor=layer_input,
                            to_tensor=layer_input,
                            attention_mask=attention_mask,
                            num_attention_heads=num_attention_heads,
                            size_per_head=attention_head_size,
                            attention_probs_dropout_prob=
                            attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            do_return_2d_tensor=True,
                            batch_size=batch_size,
                            from_seq_length=seq_length,
                            to_seq_length=seq_length)
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just concatenate
                        # them to the self-attention head before the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.variable_scope("output"):
                        attention_output = tf.layers.dense(
                            attention_output,
                            hidden_size,
                            kernel_initializer=create_initializer(
                                initializer_range))
                        attention_output = dropout(attention_output,
                                                   hidden_dropout_prob)
                        attention_output = layer_norm(attention_output +
                                                      layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.variable_scope("intermediate"):
                    intermediate_output = tf.layers.dense(
                        attention_output,
                        intermediate_size,
                        activation=intermediate_act_fn,
                        kernel_initializer=create_initializer(
                            initializer_range))

                # Down-project back to `hidden_size` then add the residual.
                with tf.variable_scope("output"):
                    layer_output = tf.layers.dense(
                        intermediate_output,
                        hidden_size,
                        kernel_initializer=create_initializer(
                            initializer_range))
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

    if do_return_all_layers:
        final_outputs = []
        for layer_output in all_layer_outputs:
            final_output = bert_utils.reshape_from_matrix(
                layer_output, input_shape)
            final_outputs.append(final_output)
        return final_outputs
    else:
        final_output = bert_utils.reshape_from_matrix(prev_output, input_shape)
        return final_output
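
The device-placement rule above bumps gpu_id every num_hidden_layers / gpu_nums layers, which is equivalent to integer division by the partition size when gpu_nums divides the layer count. A tiny pure-Python sketch of the assignment:

num_hidden_layers, gpu_nums = 12, 2
gpu_partition = num_hidden_layers // gpu_nums

gpu_id = -1
for layer_idx in range(num_hidden_layers):
    if layer_idx % gpu_partition == 0:
        gpu_id += 1
    # Closed form for the same schedule:
    assert gpu_id == layer_idx // gpu_partition
    print("layer_%d -> /gpu:%d" % (layer_idx, gpu_id))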
Ejemplo n.º 23
0
def hmm_input_ids_generation(config,
							input_ori_ids,
							input_mask,
							hmm_tran_prob_list,
							**kargs):

	mask_id = kargs.get('mask_id', 103)

	input_ori_ids = tf.cast(input_ori_ids, tf.int32)
	input_mask = tf.cast(input_mask, tf.int32)

	unk_mask = tf.cast(tf.math.equal(input_ori_ids, 100), tf.float32) # not replace unk
	cls_mask =  tf.cast(tf.math.equal(input_ori_ids, 101), tf.float32) # not replace cls
	sep_mask = tf.cast(tf.math.equal(input_ori_ids, 102), tf.float32) # not replace sep

	none_replace_mask =  unk_mask + cls_mask + sep_mask
	mask_probability = kargs.get("mask_probability", 0.2)
	replace_probability = kargs.get("replace_probability", 0.1)
	original_probability = kargs.get("original_probability", 0.1)

	input_shape_list = bert_utils.get_shape_list(input_mask, expected_rank=2)
	batch_size = input_shape_list[0]
	seq_length = input_shape_list[1]
		
	tf.logging.info("**** apply fixed_mask_prob %s **** ", str(mask_probability))
	tf.logging.info("**** apply replace_probability %s **** ", str(replace_probability))
	tf.logging.info("**** apply original_probability %s **** ", str(original_probability))

	# state, sampled_binary_mask = dynamic_span_mask_v1(batch_size, seq_length, hmm_tran_prob_list[0])
	sampled_binary_mask = mask_method(batch_size, seq_length, hmm_tran_prob_list, **kargs)

	sampled_binary_mask = input_mask * (1 - tf.cast(none_replace_mask, tf.int32)) * sampled_binary_mask
	sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32)

	replace_binary_probs = replace_probability * (sampled_binary_mask) # ~10% of selected positions get a random token
	replace_noise_dist = tf.distributions.Bernoulli(probs=replace_binary_probs, dtype=tf.float32)
	sampled_replace_binary_mask = replace_noise_dist.sample()
	sampled_replace_binary_mask = tf.cast(sampled_replace_binary_mask, tf.float32)

	ori_binary_probs = original_probability * (sampled_binary_mask - sampled_replace_binary_mask)
	ori_noise_dist = tf.distributions.Bernoulli(probs=ori_binary_probs, dtype=tf.float32)
	sampled_ori_binary_mask = ori_noise_dist.sample()
	sampled_ori_binary_mask = tf.cast(sampled_ori_binary_mask, tf.float32)

	sampled_mask_binary_mask = (sampled_binary_mask - sampled_replace_binary_mask - sampled_ori_binary_mask)
	sampled_mask_binary_mask = tf.cast(sampled_mask_binary_mask, tf.float32)

	vocab_sample_logits = tf.random.uniform(
							[batch_size, seq_length, config.vocab_size],
							minval=0.0,
							maxval=10.0,
							dtype=tf.float32)

	vocab_sample_logits = tf.nn.log_softmax(vocab_sample_logits)
	flatten_vocab_sample_logits = tf.reshape(vocab_sample_logits, 
											[batch_size*seq_length, -1])

	# sampled_logprob_temp, sampled_logprob = gumbel_softmax(flatten_vocab_sample_logits, 
	# 									temperature=0.1,
	# 									samples=config.get('gen_sample', 1))

	# sample_vocab_ids = tf.argmax(sampled_logprob, axis=1) # batch x seq

	sample_vocab_ids = tf.multinomial(flatten_vocab_sample_logits, 
								num_samples=config.get('gen_sample', 1), 
								output_dtype=tf.int32)

	sample_vocab_ids = tf.reshape(sample_vocab_ids, [batch_size, seq_length])
	sample_vocab_ids = tf.cast(sample_vocab_ids, tf.float32)
	input_ori_ids = tf.cast(input_ori_ids, tf.float32)

	output_input_ids = mask_id * tf.cast(sampled_mask_binary_mask, tf.float32) * tf.ones_like(input_ori_ids)
	output_input_ids += sample_vocab_ids * tf.cast(sampled_replace_binary_mask, tf.float32)
	output_input_ids += (1 - tf.cast(sampled_mask_binary_mask + sampled_replace_binary_mask, tf.float32)) * input_ori_ids
	output_sampled_binary_mask = sampled_mask_binary_mask + sampled_replace_binary_mask + sampled_ori_binary_mask

	print("===output_input_ids shape===", output_input_ids.get_shape())
	input_shape_list = bert_utils.get_shape_list(output_input_ids, expected_rank=2)
	print("==input shape list==", input_shape_list)

	output_sampled_binary_mask = tf.cast(output_sampled_binary_mask, tf.int32)
	if not kargs.get('use_tpu', True):
		tf.summary.scalar('mask_ratio',
				tf.reduce_sum(tf.cast(output_sampled_binary_mask, tf.float32)) /
				(1e-10 + tf.cast(tf.reduce_sum(input_mask), dtype=tf.float32)))

	return [tf.cast(output_input_ids, tf.int32), 
				output_sampled_binary_mask]
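
The nested Bernoulli draws above split the HMM-selected positions into three groups, analogous to BERT's 80/10/10 rule: with the default replace_probability=0.1 and original_probability=0.1 the expected split is roughly 81% [MASK], 10% random token, 9% kept. A standalone NumPy sketch of that decomposition:

import numpy as np

rng = np.random.default_rng(0)
selected = np.ones(100000)  # positions the HMM span sampler chose to corrupt

replaced = rng.binomial(1, 0.1 * selected)            # random vocab token
kept = rng.binomial(1, 0.1 * (selected - replaced))   # left as the original id
masked = selected - replaced - kept                   # replaced with [MASK]

print(replaced.mean(), kept.mean(), masked.mean())    # ~0.10, ~0.09, ~0.81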
Ejemplo n.º 24
0
    def model_fn(features, labels, mode):
        label_ids = features["label_ids"]
        model_lst = []
        for index, name in enumerate(input_name):
            if index > 0:
                reuse = True
            else:
                reuse = model_reuse
            model_lst.append(
                base_model(model_config,
                           features,
                           labels,
                           mode,
                           name,
                           reuse=reuse))

        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        assert len(model_lst) == len(input_name)

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):

            try:
                label_ratio_table = tf.get_variable(
                    name="label_ratio",
                    shape=[
                        num_labels,
                    ],
                    initializer=tf.constant_initializer(label_tensor),
                    trainable=False)

                ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                      label_ids)
            except Exception:
                ratio_weight = None

            seq_output_lst = [model.get_pooled_output() for model in model_lst]
            repres = seq_output_lst[0] + seq_output_lst[1]

            final_hidden_shape = bert_utils.get_shape_list(repres,
                                                           expected_rank=2)

            z_mean = tf.layers.dense(repres,
                                     final_hidden_shape[1],
                                     name="z_mean")
            z_log_var = tf.layers.dense(repres,
                                        final_hidden_shape[1],
                                        name="z_log_var")
            print("=======applying vib============")
            if mode == tf.estimator.ModeKeys.TRAIN:
                print("====applying vib====")
                vib_connector = vib.VIB(vib_config)
                [kl_loss, latent_vector
                 ] = vib_connector.build_regularizer([z_mean, z_log_var])

                [loss, per_example_loss,
                 logits] = classifier.classifier(model_config, latent_vector,
                                                 num_labels, label_ids,
                                                 dropout_prob, ratio_weight)

                loss += tf.reduce_mean(kl_loss)
            else:
                print("====applying z_mean for prediction====")
                [loss, per_example_loss,
                 logits] = classifier.classifier(model_config, z_mean,
                                                 num_labels, label_ids,
                                                 dropout_prob, ratio_weight)

        # model_io_fn = model_io.ModelIO(model_io_config)
        pretrained_tvars = model_io_fn.get_params(model_config.scope)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)

        tvars = model_io_fn.get_params(scope)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return [train_op, loss, per_example_loss, logits]
        else:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]
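
A minimal NumPy sketch (not the repo's vib.VIB module) of the two pieces the training branch combines: the reparameterized latent z = mu + sigma * eps fed to the classifier, and the KL divergence to a standard normal prior that is added to the classification loss:

import numpy as np

def vib_latent_and_kl(z_mean, z_log_var, seed=0):
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(z_mean.shape)
    latent = z_mean + np.exp(0.5 * z_log_var) * eps  # reparameterization trick
    # KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.
    kl = -0.5 * np.sum(1.0 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=-1)
    return latent, kl

z_mean = np.zeros((2, 4))
z_log_var = np.zeros((2, 4))
latent, kl = vib_latent_and_kl(z_mean, z_log_var)
print(kl)  # zeros: the posterior already matches the prior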
Ejemplo n.º 25
0
def sample_sequence(model_api,
                    model_config,
                    mode,
                    features,
                    target="",
                    start_token=101,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    n_samples=1,
                    top_k=0,
                    end_token=102,
                    greedy_or_sample="sample",
                    gumbel_temp=0.01,
                    estimator="straight_through",
                    back_prop=True,
                    swap_memory=True,
                    **kargs):

    input_shape = bert_utils.get_shape_list(features["input_ids"],
                                            expected_rank=[2, 3])
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)
        print(context.get_shape(), "===init context shape===")

    context_shape = bert_utils.get_shape_list(context, expected_rank=[2])
    actual_length = seq_length

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`

    attention_head_size = int(model_config.hidden_size /
                              model_config.num_attention_heads)

    # single layer present: [B, 2, N, T, H]
    # all layer present: [B, N_layer, 2, N, T, H]
    presents = tf.zeros(
        (batch_size, model_config.num_hidden_layers, 2,
         model_config.num_attention_heads, actual_length, attention_head_size))

    samples = tf.cast(tf.zeros((batch_size, actual_length)), tf.int32)
    end_mask = tf.expand_dims(tf.one_hot(actual_length - 1, actual_length),
                              axis=(0))
    samples += end_token * tf.cast(
        end_mask, tf.int32)  # make sure last token is end token

    #     samples += start_token * tf.einsum("ab,bc->ac",
    #                                     tf.cast(tf.ones((batch_size, tf.shape(start_mask)[0])), tf.int32),
    #                                      tf.cast(start_mask, tf.int32))

    start_mask = tf.one_hot(tf.range(0, context_shape[1]), actual_length)
    samples += tf.einsum("ab,bc->ac", context, tf.cast(start_mask, tf.int32))

    logits = tf.cast(tf.zeros((batch_size, actual_length)), tf.float32)

    #     start_mask = tf.expand_dims(tf.one_hot(0, seq_length+1), axis=(0))
    #     samples += start_token*tf.cast(start_mask, tf.int32) # make sure last token is end token

    if estimator in ["straight_through", "soft"]:
        gumbel_probs = tf.zeros((batch_size, actual_length - context_shape[1],
                                 model_config.vocab_size))

        start_probs = context
        start_one_hot = tf.one_hot(start_probs, model_config.vocab_size)
        gumbel_probs = tf.concat(
            [tf.cast(start_one_hot, tf.float32), gumbel_probs], axis=1)

    def step(step, tokens, segment_ids=None, past=None):

        token_shape = bert_utils.get_shape_list(tokens, expected_rank=[2, 3])

        features = {}
        features['input_ids'] = tokens
        if segment_ids is None:
            features['segment_ids'] = tf.cast(
                tf.zeros((token_shape[0], token_shape[1])), tf.int32)
        else:
            features['segment_ids'] = segment_ids
        if past is None:
            features['input_mask'] = tf.cast(
                tf.ones((token_shape[0], token_shape[1])), tf.int32)
            features['past'] = None
        else:
            past_shape = bert_utils.get_shape_list(past, expected_rank=[6])
            features['input_mask'] = tf.cast(
                tf.ones((past_shape[0], step + token_shape[1])), tf.int32)
            features['past'] = past[:, :, :, :, :(step), :]

        inference_model = model_api(model_config,
                                    features, [],
                                    mode,
                                    target,
                                    reuse=tf.AUTO_REUSE,
                                    **kargs)

        logits = inference_model.get_sequence_output_logits()
        next_presents = inference_model.get_present()

        next_presents_shape = bert_utils.get_shape_list(next_presents,
                                                        expected_rank=[6])
        print(presents.get_shape())
        if next_presents_shape[-2] > 0:
            print(next_presents_shape)
            print(next_presents.get_shape(), "===next presents shape===")
            #             mask = tf.expand_dims(tf.one_hot(step, seq_length+1), axis=(0, 1, 2, 3, 5))
            mask = tf.one_hot(tf.range(step, step + token_shape[1]),
                              actual_length)
            #             tf.expand_dims(tf.one_hot(tf.range(step, step+token_shape[1]), seq_length+1), axis=0)
            #             mask = tf.expand_dims(mask, axis=1)
            #             mask = tf.expand_dims(mask, axis=2)
            #             mask = tf.expand_dims(mask, axis=3)
            #             mask = tf.expand_dims(mask, axis=5)
            print(mask.get_shape(), "===mask shape===")

            past = tf.einsum("abcdef,eg->abcdgf", next_presents, mask) + past

#             past = past + tf.cast(mask, tf.float32) * next_presents

        return {
            'logits': logits,
            'presents': past,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.

        print(context[:, :-1].get_shape())
        init_context_shape = bert_utils.get_shape_list(context[:, :-1],
                                                       expected_rank=[2, 3])
        init_segment_ids = tf.cast(
            tf.zeros((init_context_shape[0], init_context_shape[1])), tf.int32)
        context_output = step(0,
                              context[:, :-1],
                              segment_ids=init_segment_ids,
                              past=presents)

        def get_samples_logits(samples, logits):
            batch_idxs = tf.range(0, tf.shape(samples)[0])
            batch_idxs = tf.expand_dims(batch_idxs, 1)
            samples = tf.expand_dims(samples, 1)

            idxs = tf.concat([batch_idxs, samples], 1)
            sample_logits = tf.gather_nd(logits, idxs)
            return sample_logits

        def body(i, past, prev, samples, segment_ids, logits):
            print(prev.get_shape(), "==prev shape==")
            next_outputs = step(i - 1,
                                prev[:, tf.newaxis],
                                segment_ids=segment_ids,
                                past=past)
            next_logits = next_outputs['logits'][:, -1, :] / tf.to_float(
                temperature)
            next_logits = tf.nn.log_softmax(next_logits, axis=-1)
            if greedy_or_sample == "sample":
                next_samples = tf.multinomial(next_logits,
                                              num_samples=1,
                                              output_dtype=tf.int32)
                next_samples = tf.squeeze(next_samples, axis=-1)
            elif greedy_or_sample == "greedy":
                next_samples = tf.argmax(next_logits, axis=-1)
            else:
                next_samples = tf.argmax(next_logits, axis=-1)
            print(next_samples.get_shape(), "==sample shape==")

            print(tf.one_hot(i, actual_length).get_shape(), "===one-hot shape===")
            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)

            return [
                i + 1, next_outputs['presents'], next_samples, samples,
                segment_ids, logits
            ]

        def gumbel_st_body(i, past, prev, samples, gumbel_probs, segment_ids,
                           logits):
            #             next_outputs = step(i-1, prev[:, tf.newaxis], past=past)
            #             gumbel_probs[:, i-1, :]
            next_outputs = step(i - 1,
                                tf.expand_dims(gumbel_probs[:, i - 1, :],
                                               axis=1),
                                segment_ids=segment_ids,
                                past=past)

            next_logits = next_outputs['logits'][:, -1, :] / tf.to_float(
                temperature)
            next_logits = tf.nn.log_softmax(next_logits, axis=-1)
            #             if greedy_or_sample == "sample":
            #                 next_samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
            #                 next_samples = tf.squeeze(next_samples, axis=-1)
            #             elif greedy_or_sample == "greedy":
            #                 next_samples = tf.argmax(logits, axis=-1, keepdims=True)
            #             else:
            #                 next_samples = tf.argmax(logits, axis=-1, keepdims=True)
            next_gumbel_probs, _ = gumbel_softmax(next_logits,
                                                  gumbel_temp,
                                                  gumbel_samples=None,
                                                  samples=1)
            next_samples = tf.cast(tf.argmax(next_gumbel_probs, axis=1),
                                   tf.int32)
            next_samples_onehot = tf.one_hot(next_samples,
                                             model_config.vocab_size,
                                             axis=1)  # sampled multinomial id
            straight_through_onehot = tf.stop_gradient(
                next_samples_onehot - next_gumbel_probs) + next_gumbel_probs

            print(next_gumbel_probs.get_shape(), "=====gumbel====",
                  straight_through_onehot.get_shape())
            gumbel_mask = tf.expand_dims(tf.expand_dims(tf.one_hot(
                i, actual_length),
                                                        axis=0),
                                         axis=2)  # [1, seq, 1]
            gumbel_probs += tf.cast(gumbel_mask, tf.float32) * tf.expand_dims(
                straight_through_onehot, axis=1)  # b x 1 x vocab

            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=0)  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)

            return [
                i + 1, next_outputs['presents'], next_samples, samples,
                gumbel_probs, segment_ids, logits
            ]

        def gumbel_soft_body(i, past, prev, samples, gumbel_probs, segment_ids,
                             logits):
            next_outputs = step(i - 1,
                                prev[:, tf.newaxis],
                                segment_ids=segment_ids,
                                past=past)
            # alternative: feed the previous step's gumbel probs instead of
            # the sampled ids:
            # next_outputs = step(i - 1,
            #                     tf.expand_dims(gumbel_probs[:, i - 1, :], axis=1),
            #                     segment_ids=segment_ids,
            #                     past=past)

            next_logits = next_outputs['logits'][:, -1, :] / tf.to_float(
                temperature)
            next_logits = tf.nn.log_softmax(next_logits, axis=-1)
            # gumbel sample
            next_gumbel_probs, _ = gumbel_softmax(next_logits,
                                                  gumbel_temp,
                                                  gumbel_samples=None,
                                                  samples=1)
            next_samples = tf.cast(tf.argmax(next_gumbel_probs, axis=1),
                                   tf.int32)
            next_samples_onehot = tf.one_hot(next_samples,
                                             model_config.vocab_size,
                                             axis=1)  # sampled multinomial id

            # straight-through token_matrix
            # straight_through_onehot = tf.stop_gradient(next_samples_onehot-next_gumbel_probs)+next_gumbel_probs

            print(next_gumbel_probs.get_shape())
            gumbel_mask = tf.expand_dims(tf.expand_dims(tf.one_hot(
                i, actual_length),
                                                        axis=0),
                                         axis=2)  # [1, seq, 1]
            gumbel_probs += tf.cast(gumbel_mask, tf.float32) * tf.expand_dims(
                next_gumbel_probs, axis=1)  # b x 1 x vocab

            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=0)  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)

            return [
                i + 1, next_outputs['presents'], next_samples, samples,
                gumbel_probs, segment_ids, logits
            ]

        init_i = bert_utils.get_shape_list(context[:, :-1],
                                           expected_rank=[2, 3])[1] + 1
        if kargs.get("mask_type", "left2right") == 'left2right':
            left_segment_ids = tf.expand_dims(tf.cast(
                tf.zeros_like(context[:, -1]), tf.int32),
                                              axis=-1)
        elif kargs.get("mask_type", "left2right") == 'seq2seq':
            left_segment_ids = tf.expand_dims(tf.cast(
                tf.ones_like(context[:, -1]), tf.int32),
                                              axis=-1)
        else:
            # default to left2right segment ids so the loop variable is
            # always defined
            left_segment_ids = tf.expand_dims(tf.cast(
                tf.zeros_like(context[:, -1]), tf.int32),
                                              axis=-1)

        if estimator == "straight_through":
            final, presents, _, samples, gumbel_probs, _, logits = tf.while_loop(
                cond=lambda i, _1, _2, _3, _4, _5, _6: i < seq_length - 1,
                body=gumbel_st_body,
                loop_vars=[
                    init_i,
                    context_output['presents'],
                    #                     presents,
                    context[:, -1],
                    samples,
                    gumbel_probs,
                    left_segment_ids,
                    logits
                ],
                back_prop=back_prop,
                swap_memory=swap_memory)

        elif estimator == "soft":
            final, presents, _, samples, gumbel_probs, _, logits = tf.while_loop(
                cond=lambda i, _1, _2, _3, _4, _5, _6: i < seq_length - 1,
                body=gumbel_soft_body,
                loop_vars=[
                    init_i,
                    context_output['presents'],
                    #                     presents,
                    context[:, -1],
                    samples,
                    gumbel_probs,
                    left_segment_ids,
                    logits
                ],
                back_prop=back_prop,
                swap_memory=swap_memory)

        else:
            final, presents, _, samples, _, logits = tf.while_loop(
                cond=lambda i, _1, _2, _3, _4, _5: i < seq_length - 1,
                body=body,
                loop_vars=[
                    init_i,
                    context_output['presents'],
                    #                     presents,
                    context[:, -1],
                    samples,
                    left_segment_ids,
                    logits
                ],
                back_prop=back_prop,
                swap_memory=swap_memory)

        mask_sequence = get_finised_pos(samples, end_token, actual_length)
        #         print(mask_sequence.get_shape())
        #         samples *= tf.cast(mask_sequence, tf.int32)
        #         logits *= tf.cast(mask_sequence, tf.float32)
        if estimator in ["straight_through", "soft"]:
            gumbel_probs *= tf.expand_dims(tf.cast(mask_sequence, tf.float32),
                                           axis=-1)
            return samples, gumbel_probs, presents, logits, final
        else:
            return samples, mask_sequence, presents, logits, final
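
The straight-through branch above hinges on one trick: take a hard one-hot
sample in the forward pass, but let gradients flow through the soft
Gumbel-softmax relaxation. A minimal, self-contained sketch of that step
(toy logits and temperature; the repo's own gumbel_softmax helper is not
shown in this listing):

import tensorflow as tf  # assumes TF 1.x, as in the surrounding code

def gumbel_softmax_st(logits, temperature=1.0):
    # Gumbel(0, 1) noise turns a softmax into a sampler
    uniform = tf.random_uniform(tf.shape(logits), minval=1e-10, maxval=1.0)
    gumbel = -tf.log(-tf.log(uniform))
    probs = tf.nn.softmax((logits + gumbel) / temperature, axis=-1)
    # hard one-hot forward, soft gradient backward
    hard = tf.one_hot(tf.argmax(probs, axis=-1), tf.shape(logits)[-1])
    return tf.stop_gradient(hard - probs) + probs

with tf.Graph().as_default(), tf.Session() as sess:
    st_onehot = gumbel_softmax_st(tf.constant([[2.0, 0.5, -1.0]]),
                                  temperature=0.5)
    print(sess.run(st_onehot))  # a one-hot row, e.g. [[1., 0., 0.]]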
Ejemplo n.º 26
0
    def build_encoder(self,
                      input_ids,
                      input_mask,
                      hidden_dropout_prob,
                      attention_probs_dropout_prob,
                      past=None,
                      decode_loop_step=None,
                      max_decode_length=None,
                      if_bp=False,
                      if_cache_decode=None,
                      **kargs):
        reuse = kargs["reuse"]
        input_shape = bert_utils.get_shape_list(input_ids,
                                                expected_rank=[2, 3])
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length],
                                 dtype=tf.int32)

        with tf.variable_scope(self.config.get("scope", "bert"), reuse=reuse):
            with tf.variable_scope("encoder"):
                # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
                # mask of shape [batch_size, seq_length, seq_length] which is used
                # for the attention scores.

                input_shape = bert_utils.get_shape_list(input_ids,
                                                        expected_rank=[2, 3])
                if len(input_shape) == 3:
                    tmp_input_ids = tf.argmax(input_ids, axis=-1)
                else:
                    tmp_input_ids = input_ids

                if decode_loop_step is None:
                    self.bi_attention_mask = bert_seq_modules.create_attention_mask_from_input_mask(
                        tmp_input_ids, input_mask)
                else:
                    if max_decode_length is None:
                        max_decode_length = self.max_position_embeddings
                    # [max_decode_length, 1]
                    input_mask = tf.expand_dims(tf.sequence_mask(
                        decode_loop_step + 1, maxlen=max_decode_length),
                                                axis=-1)
                    # [1, max_decode_length]
                    input_mask = tf.transpose(input_mask, perm=[1, 0])
                    input_mask = tf.tile(input_mask, [batch_size, 1])
                    self.bi_attention_mask = bert_seq_modules.create_attention_mask_from_input_mask(
                        tmp_input_ids, input_mask)

                seq_type = kargs.get('seq_type', "None")
                tf.logging.info("==seq_type: %s==", seq_type)

                if seq_type == "seq2seq":
                    if kargs.get("mask_type", "left2right") == "left2right":
                        mask_sequence = None
                        tf.logging.info(
                            "==apply left2right LM model with causal mask==")
                    elif kargs.get("mask_type", "left2right") == "seq2seq":
                        token_type_ids = kargs.get("token_type_ids", None)
                        tf.logging.info(
                            "==apply left2right LM model with conditional causal mask=="
                        )
                        if token_type_ids is None:
                            token_type_ids = tf.zeros_like(input_mask)
                            tf.logging.info(
                                "==conditional mask is set to 0 and degenerate to left2right LM model=="
                            )
                        mask_sequence = token_type_ids
                    else:
                        mask_sequence = None
                    if decode_loop_step is None:
                        self.attention_mask = bert_utils.generate_seq2seq_mask(
                            self.bi_attention_mask, mask_sequence, seq_type)
                    else:
                        # with a loop step, we must do causal decoding
                        self.attention_mask = bert_utils.generate_seq2seq_mask(
                            self.bi_attention_mask, None, seq_type)
                else:
                    tf.logging.info(
                        "==apply bi-directional LM model with bi-directional mask=="
                    )
                    self.attention_mask = self.bi_attention_mask

                # Run the stacked transformer.
                # `sequence_output` shape = [batch_size, seq_length, hidden_size].

                if kargs.get('attention_type',
                             'normal_attention') == 'normal_attention':
                    tf.logging.info("****** normal attention *******")
                    transformer_model = bert_seq_modules.transformer_model
                elif kargs.get('attention_type',
                               'normal_attention') == 'rezero_transformer':
                    transformer_model = bert_seq_modules.transformer_rezero_model
                    tf.logging.info("****** rezero_transformer *******")
                else:
                    tf.logging.info("****** normal attention *******")
                    transformer_model = bert_seq_modules.transformer_model

                [
                    self.all_encoder_layers, self.all_present,
                    self.all_attention_scores, self.all_value_outputs
                ] = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=self.attention_mask,
                    hidden_size=self.config.hidden_size,
                    num_hidden_layers=self.config.num_hidden_layers,
                    num_attention_heads=self.config.num_attention_heads,
                    intermediate_size=self.config.intermediate_size,
                    intermediate_act_fn=bert_seq_modules.get_activation(
                        self.config.hidden_act),
                    hidden_dropout_prob=hidden_dropout_prob,
                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                    initializer_range=self.config.initializer_range,
                    do_return_all_layers=True,
                    past=past,
                    decode_loop_step=decode_loop_step,
                    if_bp=if_bp,
                    if_cache_decode=if_cache_decode,
                    attention_fixed_size=self.config.get(
                        'attention_fixed_size', None))
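
For cached decoding, the branch above rebuilds the input mask from the loop
step alone: at step t every position up to t is visible and everything after
it is hidden. A small sketch of that mask construction (toy sizes, not from
the listing):

import tensorflow as tf  # TF 1.x

with tf.Graph().as_default(), tf.Session() as sess:
    batch_size, max_decode_length, decode_loop_step = 2, 5, 2
    # [max_decode_length, 1]: True for positions <= decode_loop_step
    mask = tf.expand_dims(
        tf.sequence_mask(decode_loop_step + 1, maxlen=max_decode_length),
        axis=-1)
    mask = tf.transpose(mask, perm=[1, 0])   # [1, max_decode_length]
    mask = tf.tile(mask, [batch_size, 1])    # [batch, max_decode_length]
    print(sess.run(tf.cast(mask, tf.int32)))  # [[1 1 1 0 0] [1 1 1 0 0]]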
Ejemplo n.º 27
0
def multi_position_crf_classifier(config, features, model_dict, num_labels,
                                  dropout_prob):

    batch_size = features['batch_size']
    total_length_a = features['total_length_a']
    total_length_b = features['total_length_b']

    sequence_output_a = model_dict["a"].get_sequence_output(
    )  # [batch x 10, 130, 768]
    shape_lst = bert_utils.get_shape_list(sequence_output_a, expected_rank=3)

    sequence_output_a = tf.reshape(
        sequence_output_a,
        [-1, total_length_a, shape_lst[-1]])  # [batch, 10 x 130, 768]
    answer_pos = tf.cast(features['label_positions'], tf.int32)
    sequence_output_a = bert_utils.gather_indexes(
        sequence_output_a, answer_pos)  # [batch*10, 768]

    sequence_output_a = tf.reshape(
        sequence_output_a, [-1, config.max_predictions_per_seq, shape_lst[-1]
                            ])  # [batch, 10, 768]

    sequence_output_b = model_dict["b"].get_pooled_output()  # [batch x 10,768]
    sequence_output_b = tf.reshape(
        sequence_output_b, [-1, num_labels, shape_lst[-1]])  # [batch, 10, 768]
    seq_b_shape = bert_utils.get_shape_list(sequence_output_b, expected_rank=3)

    cross_matrix = tf.get_variable(
        "output_weights", [shape_lst[-1], shape_lst[-1]],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    # batch x 10 x 768
    sequence_output_a_proj = tf.einsum("abc,cd->abd", sequence_output_a,
                                       cross_matrix)

    # batch x 10 x 768. batch x 10 x 768
    # batch x 10(ans_pos) x 11(ans_field)
    logits = tf.einsum("abd,acd->abc", sequence_output_a_proj,
                       sequence_output_b)
    logits = tf.multiply(
        logits, 1.0 / tf.math.sqrt(tf.cast(shape_lst[-1], tf.float32)))

    # print(sequence_output_a.get_shape(), sequence_output_b.get_shape(), logits.get_shape())

    label_ids = tf.cast(features['label_ids'], tf.int32)
    label_weights = tf.cast(features['label_weights'], tf.int32)
    label_seq_length = tf.reduce_sum(label_weights, axis=-1)

    transition = zero_transition(seq_b_shape)

    # CRF log-likelihood over the position-to-field scores; without this
    # block, loss, per_example_loss and transition_params are undefined below
    log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
        inputs=logits,
        tag_indices=label_ids,
        sequence_lengths=label_seq_length,
        transition_params=transition)

    transition_params = tf.stop_gradient(transition_params)
    per_example_loss = -log_likelihood
    loss = tf.reduce_mean(per_example_loss)

    return (loss, per_example_loss, logits, transition_params)
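
The scoring above is a bilinear form: project each answer-position vector
with a [hidden, hidden] matrix, then dot it against every label-field
vector. A toy-shape sketch of the two einsums (sizes are illustrative):

import tensorflow as tf  # TF 1.x

with tf.Graph().as_default(), tf.Session() as sess:
    batch, positions, fields, hidden = 2, 10, 11, 8
    a = tf.random_normal([batch, positions, hidden])  # answer positions
    b = tf.random_normal([batch, fields, hidden])     # label fields
    w = tf.random_normal([hidden, hidden])            # cross matrix
    a_proj = tf.einsum("abc,cd->abd", a, w)           # [batch, positions, hidden]
    scores = tf.einsum("abd,acd->abc", a_proj, b)     # [batch, positions, fields]
    scores /= tf.math.sqrt(tf.cast(hidden, tf.float32))
    print(sess.run(tf.shape(scores)))  # [ 2 10 11]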
Ejemplo n.º 28
0
    def build_embedder(self,
                       input_ids,
                       token_type_ids,
                       hidden_dropout_prob,
                       attention_probs_dropout_prob,
                       past=None,
                       decode_loop_step=None,
                       **kargs):

        reuse = kargs["reuse"]

        if self.config.get("embedding",
                           "none_factorized") == "none_factorized":
            projection_width = self.config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = self.config.get('embedding_size',
                                               self.config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        if self.config.get('embedding_scope', None):
            embedding_scope = self.config['embedding_scope']
            other_embedding_scope = self.config['embedding_scope']
            tf.logging.info(
                "==using embedding scope of original model_config.embedding_scope: %s, other_embedding_scope:%s ==",
                embedding_scope, other_embedding_scope)
        else:
            embedding_scope = self.config.get("scope", "bert")
            other_embedding_scope = self.config.get("scope", "bert")
            tf.logging.info(
                "==using embedding scope of original model_config.embedding_scope: %s, other_embedding_scope:%s ==",
                embedding_scope, other_embedding_scope)
        if past is None:
            self.past_length = 0
        else:
            # batch_size_, num_layers_, two_, num_heads_, self.cache_length, features_
            if decode_loop_step is None:
                # gpu-decode length
                past_shape = bert_utils.get_shape_list(past, expected_rank=[6])
                self.past_length = past_shape[-2]
            else:
                self.past_length = decode_loop_step

        with tf.variable_scope(embedding_scope, reuse=reuse):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                # (self.embedding_output_word, self.embedding_table) = bert_modules.embedding_lookup(
                # 		input_ids=input_ids,
                # 		vocab_size=self.config.vocab_size,
                # 		embedding_size=projection_width,
                # 		initializer_range=self.config.initializer_range,
                # 		word_embedding_name="word_embeddings",
                # 		use_one_hot_embeddings=self.config.use_one_hot_embeddings)

                input_shape = bert_utils.get_shape_list(input_ids,
                                                        expected_rank=[2, 3])
                tf.logging.info("==input_shape: %s==", str(input_shape))
                if len(input_shape) == 3:
                    tf.logging.info("****** 3D embedding matmul *******")
                    (self.embedding_output_word, self.embedding_table
                     ) = bert_modules.gumbel_embedding_lookup(
                         input_ids=input_ids,
                         vocab_size=self.config.vocab_size,
                         embedding_size=projection_width,
                         initializer_range=self.config.initializer_range,
                         word_embedding_name="word_embeddings",
                         use_one_hot_embeddings=self.config.
                         use_one_hot_embeddings)
                else:
                    # rank-2 int ids take the standard embedding lookup
                    # (get_shape_list enforces expected_rank=[2, 3] above)
                    (self.embedding_output_word,
                     self.embedding_table) = bert_modules.embedding_lookup(
                         input_ids=input_ids,
                         vocab_size=self.config.vocab_size,
                         embedding_size=projection_width,
                         initializer_range=self.config.initializer_range,
                         word_embedding_name="word_embeddings",
                         use_one_hot_embeddings=self.config.
                         use_one_hot_embeddings)

                if kargs.get("perturbation", None):
                    self.embedding_output_word += kargs["perturbation"]
                    tf.logging.info(
                        "==add word perturbation for robust learning==")

        with tf.variable_scope(other_embedding_scope, reuse=reuse):
            with tf.variable_scope("embeddings"):

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                tf.logging.info("==using segment type embedding ratio: %s==",
                                str(self.config.get("token_type_ratio", 1.0)))
                self.embedding_output = bert_seq_modules.embedding_postprocessor(
                    input_tensor=self.embedding_output_word,
                    use_token_type=kargs.get('use_token_type', True),
                    token_type_ids=token_type_ids,
                    token_type_vocab_size=self.config.type_vocab_size,
                    token_type_embedding_name="token_type_embeddings",
                    use_position_embeddings=True,
                    position_embedding_name="position_embeddings",
                    initializer_range=self.config.initializer_range,
                    max_position_embeddings=self.config.
                    max_position_embeddings,
                    dropout_prob=hidden_dropout_prob,
                    token_type_ratio=self.config.get("token_type_ratio", 1.0),
                    position_offset=self.past_length)
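
The 3-D branch above feeds relaxed one-hot probabilities rather than int
ids, so the lookup becomes a probability-weighted mixture of embedding
rows. bert_modules.gumbel_embedding_lookup itself is not shown in this
listing; a minimal sketch of what such a soft lookup reduces to:

import tensorflow as tf  # TF 1.x

with tf.Graph().as_default(), tf.Session() as sess:
    batch, seq, vocab, emb = 2, 4, 6, 8
    probs = tf.nn.softmax(tf.random_normal([batch, seq, vocab]), axis=-1)
    table = tf.random_normal([vocab, emb])
    # each output row is sum_v probs[b, s, v] * table[v]
    embeddings = tf.einsum("bsv,ve->bse", probs, table)
    print(sess.run(tf.shape(embeddings)))  # [2 4 8]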
Ejemplo n.º 29
0
def attention_layer(from_tensor,
										to_tensor,
										attention_mask=None,
										num_attention_heads=1,
										size_per_head=512,
										query_act=None,
										key_act=None,
										value_act=None,
										attention_probs_dropout_prob=0.0,
										initializer_range=0.02,
										do_return_2d_tensor=False,
										batch_size=None,
										from_seq_length=None,
										to_seq_length=None):
	"""Performs multi-headed attention from `from_tensor` to `to_tensor`.

	This is an implementation of multi-headed attention based on "Attention
	is all you Need". If `from_tensor` and `to_tensor` are the same, then
	this is self-attention. Each timestep in `from_tensor` attends to the
	corresponding sequence in `to_tensor`, and returns a fixed-width vector.

	This function first projects `from_tensor` into a "query" tensor and
	`to_tensor` into "key" and "value" tensors. These are (effectively) a list
	of tensors of length `num_attention_heads`, where each tensor is of shape
	[batch_size, seq_length, size_per_head].

	Then, the query and key tensors are dot-producted and scaled. These are
	softmaxed to obtain attention probabilities. The value tensors are then
	interpolated by these probabilities, then concatenated back to a single
	tensor and returned.

	In practice, the multi-headed attention is done with transposes and
	reshapes rather than with actual separate tensors.

	Args:
		from_tensor: float Tensor of shape [batch_size, from_seq_length,
			from_width].
		to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
		attention_mask: (optional) int32 Tensor of shape [batch_size,
			from_seq_length, to_seq_length]. The values should be 1 or 0. The
			attention scores will effectively be set to -infinity for any positions in
			the mask that are 0, and will be unchanged for positions that are 1.
		num_attention_heads: int. Number of attention heads.
		size_per_head: int. Size of each attention head.
		query_act: (optional) Activation function for the query transform.
		key_act: (optional) Activation function for the key transform.
		value_act: (optional) Activation function for the value transform.
		attention_probs_dropout_prob: (optional) float. Dropout probability of the
			attention probabilities.
		initializer_range: float. Range of the weight initializer.
		do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
			* from_seq_length, num_attention_heads * size_per_head]. If False, the
			output will be of shape [batch_size, from_seq_length, num_attention_heads
			* size_per_head].
		batch_size: (Optional) int. If the input is 2D, this might be the batch size
			of the 3D version of the `from_tensor` and `to_tensor`.
		from_seq_length: (Optional) If the input is 2D, this might be the seq length
			of the 3D version of the `from_tensor`.
		to_seq_length: (Optional) If the input is 2D, this might be the seq length
			of the 3D version of the `to_tensor`.

	Returns:
		float Tensor of shape [batch_size, from_seq_length,
			num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
			true, this will be of shape [batch_size * from_seq_length,
			num_attention_heads * size_per_head]).

	Raises:
		ValueError: Any of the arguments or tensor shapes are invalid.
	"""

	def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
													 seq_length, width):
		output_tensor = tf.reshape(
				input_tensor, [batch_size, seq_length, num_attention_heads, width])

		output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
		return output_tensor

	from_shape = bert_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
	to_shape = bert_utils.get_shape_list(to_tensor, expected_rank=[2, 3])

	if len(from_shape) != len(to_shape):
		raise ValueError(
				"The rank of `from_tensor` must match the rank of `to_tensor`.")

	if len(from_shape) == 3:
		batch_size = from_shape[0]
		from_seq_length = from_shape[1]
		to_seq_length = to_shape[1]
	elif len(from_shape) == 2:
		if (batch_size is None or from_seq_length is None or to_seq_length is None):
			raise ValueError(
					"When passing in rank 2 tensors to attention_layer, the values "
					"for `batch_size`, `from_seq_length`, and `to_seq_length` "
					"must all be specified.")

	# Scalar dimensions referenced here:
	#   B = batch size (number of sequences)
	#   F = `from_tensor` sequence length
	#   T = `to_tensor` sequence length
	#   N = `num_attention_heads`
	#   H = `size_per_head`

	from_tensor_2d = bert_utils.reshape_to_matrix(from_tensor)
	to_tensor_2d = bert_utils.reshape_to_matrix(to_tensor)

	# `query_layer` = [B*F, N*H]
	query_layer = tf.layers.dense(
			from_tensor_2d,
			num_attention_heads * size_per_head,
			activation=query_act,
			name="query",
			kernel_initializer=create_initializer(initializer_range))

	# `key_layer` = [B*T, N*H]
	key_layer = tf.layers.dense(
			to_tensor_2d,
			num_attention_heads * size_per_head,
			activation=key_act,
			name="key",
			kernel_initializer=create_initializer(initializer_range))

	# `value_layer` = [B*T, N*H]
	value_layer = tf.layers.dense(
			to_tensor_2d,
			num_attention_heads * size_per_head,
			activation=value_act,
			name="value",
			kernel_initializer=create_initializer(initializer_range))

	# `query_layer` = [B, N, F, H]
	query_layer = transpose_for_scores(query_layer, batch_size,
																		 num_attention_heads, from_seq_length,
																		 size_per_head)

	# `key_layer` = [B, N, T, H]
	key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
																	 to_seq_length, size_per_head)

	# Take the dot product between "query" and "key" to get the raw
	# attention scores.
	# `attention_scores` = [B, N, F, T]
	attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
	attention_scores = tf.multiply(attention_scores,
																 1.0 / math.sqrt(float(size_per_head)))

	if attention_mask is not None:
		# `attention_mask` = [B, 1, F, T]
		attention_mask = tf.expand_dims(attention_mask, axis=[1])

		# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
		# masked positions, this operation will create a tensor which is 0.0 for
		# positions we want to attend and -10000.0 for masked positions.
		adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

		# Since we are adding it to the raw scores before the softmax, this is
		# effectively the same as removing these entirely.
		attention_scores += adder

	# Normalize the attention scores to probabilities.
	# `attention_probs` = [B, N, F, T]
	# attention_probs = tf.nn.softmax(attention_scores)
	attention_probs = tf.exp(tf.nn.log_softmax(attention_scores))

	# This is actually dropping out entire tokens to attend to, which might
	# seem a bit unusual, but is taken from the original Transformer paper.
	attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

	# `value_layer` = [B, T, N, H]
	value_layer = tf.reshape(
			value_layer,
			[batch_size, to_seq_length, num_attention_heads, size_per_head])

	# `value_layer` = [B, N, T, H]
	value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

	# `context_layer` = [B, N, F, H]
	context_layer = tf.matmul(attention_probs, value_layer)

	# `context_layer` = [B, F, N, H]
	context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

	if do_return_2d_tensor:
		# `context_layer` = [B*F, N*H]
		context_layer = tf.reshape(
				context_layer,
				[batch_size * from_seq_length, num_attention_heads * size_per_head])
	else:
		# `context_layer` = [B, F, N*H]
		context_layer = tf.reshape(
				context_layer,
				[batch_size, from_seq_length, num_attention_heads * size_per_head])

	return context_layer
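
Stripped of masking, dropout and the 2D/3D reshaping, the core of
attention_layer is two matmuls around a softmax in the [B, N, F/T, H]
layout. A standalone sketch with toy shapes:

import math
import tensorflow as tf  # TF 1.x

with tf.Graph().as_default(), tf.Session() as sess:
    B, N, F, T, H = 2, 4, 5, 5, 8
    q = tf.random_normal([B, N, F, H])
    k = tf.random_normal([B, N, T, H])
    v = tf.random_normal([B, N, T, H])
    scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(float(H))
    probs = tf.nn.softmax(scores)       # [B, N, F, T]
    context = tf.matmul(probs, v)       # [B, N, F, H]
    print(sess.run(tf.shape(context)))  # [2 4 5 8]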
Ejemplo n.º 30
0
def hidden_cls_matching(teacher_hidden, student_hidden, match_direction=0):

    teacher_shape = bert_utils.get_shape_list(teacher_hidden[0],
                                              expected_rank=[3])
    student_shape = bert_utils.get_shape_list(student_hidden[0],
                                              expected_rank=[3])

    if match_direction == 0:

        with tf.variable_scope("attention_weights", reuse=tf.AUTO_REUSE):
            projection_weights = tf.get_variable(
                "attention_score_weights",
                [len(student_hidden), len(teacher_hidden)],
                initializer=tf.constant_initializer(np.ones(
                    (len(student_hidden), len(teacher_hidden))) /
                                                    len(teacher_hidden),
                                                    dtype=tf.float32))
            normalized_weights = tf.abs(projection_weights) / tf.reduce_sum(
                tf.abs(projection_weights), axis=-1, keepdims=True)

    else:
        print("===apply teacher model to student model==")
        with tf.variable_scope("attention_weights", reuse=tf.AUTO_REUSE):
            projection_weights = tf.get_variable(
                "attention_score_weights",
                [len(student_hidden), len(teacher_hidden)],
                initializer=tf.constant_initializer(np.ones(
                    (len(student_hidden), len(teacher_hidden))) /
                                                    len(student_hidden),
                                                    dtype=tf.float32))
            normalized_weights = tf.abs(projection_weights) / tf.reduce_sum(
                tf.abs(projection_weights), axis=0, keepdims=True)

    # B X F X H

    def projection_fn(input_tensor):

        with tf.variable_scope("uniformal_mapping/projection",
                               reuse=tf.AUTO_REUSE):
            projection_weights = tf.get_variable(
                "output_weights", [student_shape[-1], teacher_shape[-1]],
                initializer=tf.truncated_normal_initializer(stddev=0.02))

            input_tensor_projection = tf.einsum("ac,cd->ad", input_tensor,
                                                projection_weights)
            return input_tensor_projection

    loss = tf.constant(0.0)
    for i in range(len(student_hidden)):
        student_hidden_ = student_hidden[i][:, 0:1, :]
        student_hidden_ = tf.squeeze(student_hidden_, axis=1)
        student_hidden_ = projection_fn(student_hidden_)
        student_hidden_ = tf.nn.l2_normalize(student_hidden_, axis=-1)
        for j in range(len(teacher_hidden)):
            teacher_hidden_ = teacher_hidden[j][:, 0:1, :]
            teacher_hidden_ = tf.squeeze(teacher_hidden_, axis=1)
            teacher_hidden_ = tf.nn.l2_normalize(teacher_hidden_, axis=-1)
            weight = normalized_weights[i, j]  # normalized to [0,1]
            tmp_loss = weight * l1_distance(
                student_hidden_, teacher_hidden_, axis=-1)
            loss += tf.reduce_mean(tmp_loss, axis=0)
    loss /= (len(student_hidden) * len(teacher_hidden))
    return loss
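
l1_distance is defined elsewhere in the repo; a plausible minimal version
consistent with its use above (a per-example L1 distance along the hidden
axis of two L2-normalized [batch, hidden] tensors) looks like this:

import tensorflow as tf  # TF 1.x

def l1_distance_sketch(x, y, axis=-1):
    # sum of absolute differences along the hidden axis -> shape [batch]
    return tf.reduce_sum(tf.abs(x - y), axis=axis)

with tf.Graph().as_default(), tf.Session() as sess:
    s = tf.nn.l2_normalize(tf.random_normal([3, 8]), axis=-1)
    t = tf.nn.l2_normalize(tf.random_normal([3, 8]), axis=-1)
    print(sess.run(l1_distance_sketch(s, t)))  # three distances, one per row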