Example #1
def get_masked_lm_output(config,
                         input_tensor,
                         output_weights,
                         positions,
                         label_ids,
                         label_weights,
                         reuse=None):
    """Get loss and log probs for the masked LM."""
    input_tensor = tf.cast(input_tensor, tf.float32)
    positions = tf.cast(positions, tf.int32)
    label_ids = tf.cast(label_ids, tf.int32)
    label_weights = tf.cast(label_weights, tf.float32)

    # Gather the hidden states at the masked positions; the result is
    # flattened to [batch * num_predictions, hidden].
    input_tensor = bert_utils.gather_indexes(input_tensor, positions)
    with tf.variable_scope("cls/predictions", reuse=reuse):
        # We apply one more non-linear transformation before the output layer.
        # This matrix is not used after pre-training.
        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=config.hidden_size,
                activation=bert_modules.get_activation(config.hidden_act),
                kernel_initializer=bert_modules.create_initializer(
                    config.initializer_range))
            input_tensor = bert_modules.layer_norm(input_tensor)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(label_ids, [-1])
        label_weights = tf.reshape(label_weights, [-1])

        # one_hot_labels = tf.one_hot(
        # 		label_ids, depth=config.vocab_size, dtype=tf.float32)

        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_ids, logits=logits)
        # Dense alternative:
        # per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        numerator = tf.reduce_sum(label_weights * per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator

    return (loss, per_example_loss, log_probs)
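
Below is a minimal NumPy sketch (not the repository's `bert_utils` API; the helper names are hypothetical) of the two ideas used above: gathering the hidden states at the masked positions into a flat [batch * num_predictions, hidden] tensor, and averaging the per-example losses with the 0/1 label_weights so that zero-padded prediction slots do not contribute.

import numpy as np

def gather_positions(sequence_output, positions):
    # sequence_output: [batch, seq_len, hidden]; positions: [batch, num_preds]
    batch_idx = np.arange(sequence_output.shape[0])[:, None]       # [batch, 1]
    gathered = sequence_output[batch_idx, positions]               # [batch, num_preds, hidden]
    return gathered.reshape(-1, sequence_output.shape[-1])         # [batch * num_preds, hidden]

def masked_lm_loss(per_example_loss, label_weights):
    # label_weights is 1.0 for real predictions and 0.0 for padded slots.
    numerator = np.sum(label_weights * per_example_loss)
    denominator = np.sum(label_weights) + 1e-5                     # avoid division by zero
    return numerator / denominator

hidden_states = np.random.randn(2, 8, 4).astype(np.float32)
positions = np.array([[1, 3, 0], [2, 5, 0]])                       # third slot is padding
flat = gather_positions(hidden_states, positions)                  # shape (6, 4)
loss = masked_lm_loss(np.array([0.7, 1.2, 9.9, 0.3, 0.5, 9.9]),
                      np.array([1.0, 1.0, 0.0, 1.0, 1.0, 0.0]))
print(flat.shape, loss)                                            # padded slots are ignored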
Example #2
    def build_output_logits(self, **kargs):
        layer_num = kargs.get("layer_num", -1)
        self.sequence_output = self.get_encoder_layers(layer_num)
        input_shape_list = bert_utils.get_shape_list(self.sequence_output,
                                                     expected_rank=3)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        hidden_dims = input_shape_list[2]

        embedding_projection = kargs.get('embedding_projection', None)

        scope = kargs.get('scope', None)
        if scope:
            scope = scope + '/' + 'cls/predictions'
        else:
            scope = 'cls/predictions'

        tf.logging.info("**** mlm generator scope **** %s", str(scope))

        # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            if self.config.get('ln_type', 'postln') == 'preln':
                input_tensor = bert_modules.layer_norm(self.sequence_output)
                tf.logging.info("**** pre ln doing layer norm ****")
            elif self.config.get('ln_type', 'postln') == 'postln':
                input_tensor = self.sequence_output
                tf.logging.info("**** post ln ****")
            else:
                input_tensor = self.sequence_output
                tf.logging.info("**** post ln ****")

            # if config.get("embedding", "factorized") == "factorized":
            # 	projection_width = config.hidden_size
            # else:
            # 	projection_width = config.embedding_size

            if self.config.get("embedding",
                               "none_factorized") == "none_factorized":
                projection_width = self.config.hidden_size
                tf.logging.info("==not using embedding factorized==")
            else:
                projection_width = self.config.get('embedding_size',
                                                   self.config.hidden_size)
                tf.logging.info(
                    "==using embedding factorized: embedding size: %s==",
                    str(projection_width))

            with tf.variable_scope("transform"):
                input_tensor = tf.layers.dense(
                    input_tensor,
                    units=projection_width,
                    activation=bert_modules.get_activation(
                        self.config.hidden_act),
                    kernel_initializer=bert_modules.create_initializer(
                        self.config.initializer_range))

                if self.config.get('ln_type', 'postln') == 'preln':
                    input_tensor = input_tensor
                    tf.logging.info("**** pre ln ****")
                elif self.config.get('ln_type', 'postln') == 'postln':
                    input_tensor = bert_modules.layer_norm(input_tensor)
                    tf.logging.info("**** post ln doing layer norm ****")
                else:
                    input_tensor = bert_modules.layer_norm(input_tensor)
                    tf.logging.info("**** post ln doing layer norm ****")

            if embedding_projection is not None:
                # batch x seq x hidden, embedding x hidden
                print(input_tensor.get_shape(),
                      embedding_projection.get_shape())
                input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                         embedding_projection)
            else:
                print("==no need for embedding projection==")
                input_tensor = input_tensor

            output_bias = tf.get_variable("output_bias",
                                          shape=[self.config.vocab_size],
                                          initializer=tf.zeros_initializer())
            # logits: batch x seq x vocab (embedding_table is [vocab, embedding])
            logits = tf.einsum("abc,dc->abd", input_tensor,
                               self.embedding_table)
            self.logits = tf.nn.bias_add(logits, output_bias)
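
A short plain-NumPy sketch of the tied-embedding output layer used above: with transformed hidden states of shape [batch, seq, embedding] and an embedding table of shape [vocab, embedding], the einsum "abc,dc->abd" contracts the embedding dimension and yields logits over the vocabulary. The sizes below are made-up example values.

import numpy as np

batch, seq, embed, vocab = 2, 5, 8, 16
hidden = np.random.randn(batch, seq, embed).astype(np.float32)      # output of the "transform" layer
embedding_table = np.random.randn(vocab, embed).astype(np.float32)  # shared input/output embeddings
output_bias = np.zeros(vocab, dtype=np.float32)

# contract the embedding dimension against every vocabulary row
logits = np.einsum("abc,dc->abd", hidden, embedding_table) + output_bias
print(logits.shape)  # (2, 5, 16) -> batch x seq x vocab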
Example #3
def multi_position_classifier(config, features, sequence_output, num_labels,
                              dropout_prob):

    final_hidden_shape = bert_utils.get_shape_list(sequence_output,
                                                   expected_rank=3)

    print(final_hidden_shape, "====multi-choice shape====")

    answer_pos = tf.cast(features['label_positions'], tf.int32)
    cls_pos = tf.zeros_like(answer_pos)
    input_tensor = bert_utils.gather_indexes(sequence_output, answer_pos)
    cls_tensor = bert_utils.gather_indexes(sequence_output, cls_pos)

    answer_cls_tensor = tf.concat([cls_tensor, input_tensor], axis=-1)

    input_tensor = tf.layers.dense(
        answer_cls_tensor,
        units=config.hidden_size,
        activation=bert_modules.get_activation(config.hidden_act),
        kernel_initializer=bert_modules.create_initializer(
            config.initializer_range))
    input_tensor = bert_modules.layer_norm(input_tensor)

    output_weights = tf.get_variable(
        "output_weights", [num_labels, final_hidden_shape[-1]],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable("output_bias",
                                  shape=[num_labels],
                                  initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    label_ids = tf.reshape(tf.cast(features['label_ids'], tf.int32), [-1])
    label_weights = tf.reshape(tf.cast(features['label_weights'], tf.float32),
                               [-1])

    # `class_weights` is only defined when supplied in the config; the
    # 'class_balanced_focal' branch below assumes it is present.
    if config.get('class_weights', None):
        class_weights = tf.constant(
            np.array(config.class_weights).astype(np.float32))

    if config.get("loss", "entropy") == "focal_loss":
        per_example_loss, _ = loss_utils.focal_loss_multi_v1(
            config, logits=logits, labels=tf.stop_gradient(label_ids))
    elif config.get("loss", "smoothed_ce") == 'smoothed_ce':
        per_example_loss = loss_utils.ce_label_smoothing(
            config, logits=logits, labels=tf.stop_gradient(label_ids))
    elif config.get('loss', 'class_balanced_focal') == 'class_balanced_focal':
        per_example_loss, _ = loss_utils.class_balanced_focal_loss_multi_v1(
            config,
            logits=logits,
            labels=label_ids,
            label_weights=class_weights)
    else:
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(label_ids), logits=logits)

    numerator = tf.reduce_sum(label_weights * per_example_loss)
    denominator = tf.reduce_sum(label_weights) + 1e-5
    loss = numerator / denominator

    return (loss, per_example_loss, logits)
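
A minimal sketch (hypothetical helper name; NumPy stand-in for bert_utils.gather_indexes) of how the classifier input above is assembled: each answer position's hidden state is paired with the sequence's position-0 ([CLS]) hidden state, and the two are concatenated before the dense layer.

import numpy as np

def gather_positions(sequence_output, positions):
    batch_idx = np.arange(sequence_output.shape[0])[:, None]
    gathered = sequence_output[batch_idx, positions]               # [batch, num_pos, hidden]
    return gathered.reshape(-1, sequence_output.shape[-1])         # [batch * num_pos, hidden]

sequence_output = np.random.randn(2, 10, 4).astype(np.float32)
answer_pos = np.array([[3, 7], [5, 0]])
cls_pos = np.zeros_like(answer_pos)                                # position 0 holds [CLS]

answer_tensor = gather_positions(sequence_output, answer_pos)      # (4, 4)
cls_tensor = gather_positions(sequence_output, cls_pos)            # (4, 4)
answer_cls = np.concatenate([cls_tensor, answer_tensor], axis=-1)  # (4, 8), fed to the dense layer
print(answer_cls.shape)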
Example #4
def get_masked_lm_output(config, input_tensor, output_weights, positions,
							label_ids, label_weights, **kargs):
	"""Get loss and log probs for the masked LM."""
	reuse = kargs.get('reuse', False)
	input_tensor = tf.cast(input_tensor, tf.float32)
	positions = tf.cast(positions, tf.int32)
	label_ids = tf.cast(label_ids, tf.int32)
	label_weights = tf.cast(label_weights, tf.float32)

	scope = kargs.get('scope', None)
	if scope:
		scope = scope + '/' + 'cls/predictions'
	else:
		scope = 'cls/predictions'

	tf.logging.info("**** mlm scope **** %s", str(scope))

	# if config.get("embedding", "factorized") == "factorized":
	# 	projection_width = config.hidden_size
	# else:
	# 	projection_width = config.embedding_size

	if config.get("embedding", "none_factorized") == "none_factorized":
		projection_width = config.hidden_size
		tf.logging.info("==not using embedding factorized==")
	else:
		projection_width = config.get('embedding_size', config.hidden_size)
		tf.logging.info("==using embedding factorized: embedding size: %s==", str(projection_width))

	# Gather the hidden states at the masked positions; the result is
	# flattened to [batch * num_predictions, hidden].
	input_tensor = bert_utils.gather_indexes(input_tensor, positions)
	# with tf.variable_scope("cls/predictions", reuse=reuse):
	with tf.variable_scope(scope, reuse=reuse):
		# We apply one more non-linear transformation before the output layer.
		# This matrix is not used after pre-training.
		with tf.variable_scope("transform"):
			input_tensor = tf.layers.dense(
					input_tensor,
					units=projection_width,
					activation=bert_modules.get_activation(config.hidden_act),
					kernel_initializer=bert_modules.create_initializer(
							config.initializer_range))
			input_tensor = bert_modules.layer_norm(input_tensor)

		# The output weights are the same as the input embeddings, but there is
		# an output-only bias for each token.
		output_bias = tf.get_variable(
				"output_bias",
				shape=[config.vocab_size],
				initializer=tf.zeros_initializer())
		logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
		logits = tf.nn.bias_add(logits, output_bias)
		log_probs = tf.nn.log_softmax(logits, axis=-1)

		label_ids = tf.reshape(label_ids, [-1])
		label_weights = tf.reshape(label_weights, [-1])

		# Used only by the (commented-out) dense cross-entropy alternative below.
		one_hot_labels = tf.one_hot(
				label_ids, depth=config.vocab_size, dtype=tf.float32)

		per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
													labels=tf.stop_gradient(label_ids),
													logits=logits)
		# Dense alternative:
		# per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])

		# The `positions` tensor might be zero-padded (if the sequence is too
		# short to have the maximum number of predictions). The `label_weights`
		# tensor has a value of 1.0 for every real prediction and 0.0 for the
		# padding predictions.
		numerator = tf.reduce_sum(label_weights * per_example_loss)
		denominator = tf.reduce_sum(label_weights) + 1e-5
		loss = numerator / denominator

	return (loss, per_example_loss, log_probs, label_weights)
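
A small sketch of the projection-width selection above, using a plain dict as a stand-in for the repository's config object (the values are hypothetical): when the embedding is factorized, the "transform" layer projects to the smaller embedding_size so the logits can be tied to the factorized embedding table; otherwise it projects back to hidden_size.

config = {"hidden_size": 768, "embedding": "factorized", "embedding_size": 128}

if config.get("embedding", "none_factorized") == "none_factorized":
    projection_width = config["hidden_size"]  # no factorization: project back to hidden size
else:
    # factorized embeddings: project down to the (smaller) embedding size
    projection_width = config.get("embedding_size", config["hidden_size"])

print(projection_width)  # 128 with the settings above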
Example #5
def seq_mask_masked_lm_output(config, input_tensor, output_weights,
							input_mask, input_ori_ids, input_ids, 
							sampled_binary_mask, **kargs):

	input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
	batch_size = input_shape_list[0]
	seq_length = input_shape_list[1]
	hidden_dims = input_shape_list[2]

	embedding_projection = kargs.get('embedding_projection', None)

	scope = kargs.get('scope', None)
	if scope:
		scope = scope + '/' + 'cls/predictions'
	else:
		scope = 'cls/predictions'

	tf.logging.info("**** mlm generator scope **** %s", str(scope))

	# with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
	with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
		if config.get('ln_type', 'postln') == 'preln':
			input_tensor = bert_modules.layer_norm(input_tensor)
		elif config.get('ln_type', 'postln') == 'postln':
			input_tensor = input_tensor
		else:
			input_tensor = input_tensor

		# if config.get("embedding", "factorized") == "factorized":
		# 	projection_width = config.hidden_size
		# else:
		# 	projection_width = config.embedding_size

		if config.get("embedding", "none_factorized") == "none_factorized":
			projection_width = config.hidden_size
			tf.logging.info("==not using embedding factorized==")
		else:
			projection_width = config.get('embedding_size', config.hidden_size)
			tf.logging.info("==using embedding factorized: embedding size: %s==", str(projection_width))

		with tf.variable_scope("transform"):
			input_tensor = tf.layers.dense(
					input_tensor,
					units=projection_width,
					activation=bert_modules.get_activation(config.hidden_act),
					kernel_initializer=bert_modules.create_initializer(
							config.initializer_range))

			if config.get('ln_type', 'postln') == 'preln':
				input_tensor = input_tensor
			elif config.get('ln_type', 'postln') == 'postln':
				input_tensor = bert_modules.layer_norm(input_tensor)
			else:
				input_tensor = bert_modules.layer_norm(input_tensor)

		if embedding_projection is not None:
			# batch x seq x hidden, embedding x hidden
			print(input_tensor.get_shape(), embedding_projection.get_shape())
			input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection)
		else:
			print("==no need for embedding projection==")
			input_tensor = input_tensor

		output_bias = tf.get_variable(
				"output_bias",
				shape=[config.vocab_size],
				initializer=tf.zeros_initializer())
		# logits: batch x seq x vocab (output_weights is [vocab, embedding])
		logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
		logits = tf.nn.bias_add(logits, output_bias)

		"""
		if input_ori_ids[i] is random pertubated, sampled_binary_mask[i]=1
		"""
		sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32)
		input_mask = tf.cast(input_mask, tf.float32)

		sampled_binary_mask *= input_mask

		per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
												logits=logits,
												labels=tf.stop_gradient(input_ori_ids),
												)
		per_example_loss *= sampled_binary_mask
		loss = tf.reduce_sum(per_example_loss) / (1e-10 + tf.reduce_sum(sampled_binary_mask))

		return (loss, per_example_loss, logits, sampled_binary_mask)
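
A NumPy sketch of the loss at the end of this function: the per-token cross-entropy is kept only where both sampled_binary_mask (the perturbed positions) and input_mask are 1, and is averaged over that count. Function and variable names here are illustrative, not the repository's API.

import numpy as np

def sampled_mask_loss(per_token_loss, sampled_binary_mask, input_mask):
    # per_token_loss: [batch, seq]; both masks: [batch, seq] with values in {0, 1}
    mask = sampled_binary_mask.astype(np.float32) * input_mask.astype(np.float32)
    return np.sum(per_token_loss * mask) / (1e-10 + np.sum(mask))

per_token_loss = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
sampled = np.array([[0, 1, 1], [1, 0, 0]])       # 1 = position was perturbed
input_mask = np.array([[1, 1, 0], [1, 1, 1]])    # 0 = padding
print(sampled_mask_loss(per_token_loss, sampled, input_mask))  # (2.0 + 4.0) / 2 = 3.0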
Example #6
def emb_score(config, input_tensor, input_ids, 
				output_weights,
				input_mask, **kargs):

	input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
	batch_size = input_shape_list[0]
	seq_length = input_shape_list[1]
	hidden_dims = input_shape_list[2]

	scope = kargs.get('scope', None)
	if scope:
		lm_scope = scope + '/' + 'cls/predictions'
	else:
		lm_scope = 'cls/predictions'

	tf.logging.info("**** mlm generator scope **** %s", str(lm_scope))

	# with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
	with tf.variable_scope(lm_scope, reuse=tf.AUTO_REUSE):
		if config.get('ln_type', 'postln') == 'preln':
			input_tensor = bert_modules.layer_norm(input_tensor)
		elif config.get('ln_type', 'postln') == 'postln':
			input_tensor = input_tensor
		else:
			input_tensor = input_tensor

		if config.get("embedding", "none_factorized") == "none_factorized":
			projection_width = config.hidden_size
			tf.logging.info("==not using embedding factorized==")
		else:
			projection_width = config.get('embedding_size', config.hidden_size)
			tf.logging.info("==using embedding factorized: embedding size: %s==", str(projection_width))

		if kargs.get("energy_pooling", "mi") == "mi":
			with tf.variable_scope("transform"):
				input_tensor = tf.layers.dense(
						input_tensor,
						units=projection_width,
						activation=bert_modules.get_activation(config.hidden_act),
						kernel_initializer=bert_modules.create_initializer(
								config.initializer_range))

				if config.get('ln_type', 'postln') == 'preln':
					input_tensor = input_tensor
				elif config.get('ln_type', 'postln') == 'postln':
					input_tensor = bert_modules.layer_norm(input_tensor)
				else:
					input_tensor = bert_modules.layer_norm(input_tensor)
			output_bias = tf.get_variable(
				"output_bias",
				shape=[config.vocab_size],
				initializer=tf.zeros_initializer())
			tf.logging.info("****** mi using mlm transform *******")
		elif kargs.get("energy_pooling", "mi") == "cls":
			with tf.variable_scope("transform_ebm"):
				# We "pool" the model by simply taking the hidden state corresponding
				# to the first token. We assume that this has been pre-trained
				first_token_tensor = tf.squeeze(input_tensor[:, 0:1, :], axis=1)
				input_tensor = tf.layers.dense(
						first_token_tensor,
						config.hidden_size,
						activation=tf.tanh, #bert_modules.get_activation(config.hidden_act),
						kernel_initializer=bert_modules.create_initializer(config.initializer_range))
				tf.logging.info("****** using cls pooling *******")
		else:
			with tf.variable_scope("transform_ebm"):
				input_tensor = tf.layers.dense(
				  input_tensor,
				  units=projection_width,
				  activation=tf.tanh, #bert_modules.get_activation(config.hidden_act),
				  kernel_initializer=bert_modules.create_initializer(
					  config.initializer_range))
			tf.logging.info("****** using other pooling transform *******")

	# with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
	if scope:
		ebm_scope = scope + '/' + 'ebm/predictions'
	else:
		ebm_scope = 'ebm/predictions'
	
	tf.logging.info("**** ebm generator scope **** %s", str(ebm_scope))

	print(input_tensor.get_shape(), "==input_tensor shape==")

	with tf.variable_scope(ebm_scope, reuse=tf.AUTO_REUSE):
		# assume the whole model is self-normalization

		if kargs.get("normalized_constant", "constant") == 'zero_constant':
			normalized_constant = tf.get_variable(
					"ebm_normalized_constant",
					shape=[config.max_position_embeddings],
					initializer=tf.zeros_initializer())

			valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32) # batch_size
			onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)

			input_normalized_constant = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), normalized_constant)

			tf.logging.info("****** zero_constant logz *******")
		elif kargs.get("normalized_constant", "constant") == 'one_constant':
			normalized_constant = tf.get_variable(
					"ebm_normalized_constant",
					shape=[config.max_position_embeddings],
					initializer=tf.ones_initializer())
			tf.logging.info("****** one_constant logz *******")
			valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32) # batch_size
			onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)

			input_normalized_constant = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), normalized_constant)

		elif kargs.get("normalized_constant", "constant") == 'constant_constant':
			normalized_constant = tf.get_variable(
					"ebm_normalized_constant",
					shape=[config.max_position_embeddings],
					initializer=tf.constant_initializer(np.ones((config.max_position_embeddings))*200.0, tf.float32))
			tf.logging.info("****** one_constant logz *******")
			valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32) # batch_size
			onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)

			input_normalized_constant = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), normalized_constant)

		elif kargs.get("normalized_constant", "constant") == 'log9_constant':
			normalized_constant = tf.get_variable(
					"ebm_normalized_constant",
					shape=[config.max_position_embeddings],
					initializer=tf.constant_initializer(np.ones((config.max_position_embeddings))*np.log(9.0), tf.float32))
			tf.logging.info("****** one_constant logz *******")
			valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32) # batch_size
			onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)

			input_normalized_constant = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), normalized_constant)

		elif kargs.get("normalized_constant", "constant") == 'logv_constant':
			normalized_constant = tf.get_variable(
					"ebm_normalized_constant",
					shape=[config.max_position_embeddings],
					initializer=tf.constant_initializer(np.ones((config.max_position_embeddings))*np.log(config.vocab_size), tf.float32))
			tf.logging.info("****** one_constant logz *******")
			valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32) # batch_size
			onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)

			input_normalized_constant = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), normalized_constant)

		elif kargs.get("normalized_constant", "constant") == 'logv_constant_ln':
			normalized_constant = tf.get_variable(
					"ebm_normalized_constant",
					shape=[],
					initializer=tf.constant_initializer(np.log(config.vocab_size), tf.float32))

			input_normalized_constant = normalized_constant

		elif kargs.get("normalized_constant", "length_linear") == 'length_linear':
			normalized_constant = tf.get_variable(
					"ebm_normalized_constant",
					shape=[config.max_position_embeddings],
					initializer=tf.constant_initializer(np.arange((config.max_position_embeddings))+1, tf.float32),
					trainable=False)
			scale_weights = tf.get_variable(
					"ebm_normalized_constant_scale",
					shape=[config.max_position_embeddings],
					initializer=tf.constant_initializer(np.log(config.vocab_size)*np.ones((config.max_position_embeddings)), dtype=tf.float32),
					trainable=True)
			scale_bias = tf.get_variable(
					"ebm_normalized_constant_bias",
					shape=[config.max_position_embeddings],
					initializer=tf.zeros_initializer(),
					trainable=True)
			tf.logging.info("****** length linear logz *******")
			# normalized_constant = scale_bias + scale_weights * tf.pow(normalized_constant, 2)

			valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32) # batch_size
			onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)
			
			length_part = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), normalized_constant)
			length_scale_part = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), scale_weights)
			length_bias_part = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), scale_bias)

			input_normalized_constant = length_part*length_scale_part + length_bias_part

		# input_normalized_constant = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), normalized_constant)

		# f_input_mask = tf.cast(tf.expand_dims(input_mask, axis=-1), tf.float32)

		if kargs.get("energy_pooling", "mi") == "mean_pooling":
			tf.logging.info("==apply mean pooling to get hidden states projections==")
			# For an input token sequence <start> a b c, the energy is computed
			# over a, b, c only, so <start> does not contribute to the final
			# energy function.
			# batch x dim
			pool_features = tf.einsum("abc,ab->ac", input_tensor[:, 1:], tf.cast(input_mask[:, 1:], tf.float32))
			pool_features /= (1e-10+tf.reduce_sum(tf.cast(input_mask[:, 1:], tf.float32), axis=1, keepdims=True))
			# tf.reduce_sum(input_tensor*f_input_mask, axis=1) #/ (1e-10+tf.reduce_sum(f_input_mask, axis=1))

			print(pool_features.get_shape(), "===pool_features shape===")
		elif kargs.get("energy_pooling", "mi") == "mi":
			tf.logging.info("==apply mi to get hidden states projections==")
			# input_tensor_norm = tf.expand_dims(tf.sqrt(tf.reduce_sum(tf.pow(input_tensor, 2), axis=-1))+1e-20, axis=-1)
			# input_tensor = input_tensor / tf.stop_gradient(input_tensor_norm)
			# output_weights_norm = tf.expand_dims(tf.sqrt(tf.reduce_sum(tf.pow(output_weights, 2), axis=-1))+1e-20, axis=-1)
			# output_weights = output_weights / tf.stop_gradient(output_weights_norm)
			# with the (commented-out) l2 normalization above, this would be a
			# cosine similarity, bounding the mi estimate to [-1, 1]
			logits = tf.einsum("abc,dc->abd", input_tensor, output_weights) # batch x seq x vocab
			logits = tf.nn.bias_add(logits, output_bias)

			input_id_shape = bert_utils.get_shape_list(input_ids, [2,3])
			if len(input_id_shape) == 2:
				onehot_input_ids = tf.cast(tf.one_hot(tf.cast(input_ids, tf.int32), config.vocab_size), tf.float32) # batch x seq x vocab
				input_ori_ids = tf.cast(onehot_input_ids, tf.float32)
				print("==input ori ids shape== 2-dim", input_ori_ids.get_shape())
			else:
				input_ori_ids = tf.cast(input_ids, tf.float32)
				print("==input ori ids shape== 3-dim", input_ori_ids.get_shape())

			logits = tf.einsum("abd,abd->ab", logits, input_ori_ids)
			print(logits.get_shape(), "==pooled logits shape==")
			# with l2-normalization the per-token logits would be bounded by 1
			pool_features = tf.reduce_sum(logits[:, 1:]*tf.cast(input_mask[:, 1:], tf.float32), axis=1) #/ (1e-10+tf.reduce_sum(tf.cast(input_mask[:, 1:], tf.float32), axis=1))
			pool_features = tf.expand_dims(pool_features, axis=-1)
			print(pool_features.get_shape(), "==pooled feature shape==")

			if kargs.get("softplus_features", False):
				# softplus(-x) goes to 0 as pooled_features -> +infinity and grows
				# without bound as pooled_features -> -infinity
				pool_features = tf.nn.softplus(-pool_features)
				tf.logging.info("****** apply softplus transformation for pooled_features *******")

		elif kargs.get("energy_pooling", "mi") == "cls":
			with tf.variable_scope("transform"):
				pool_features = tf.layers.dense(
						input_tensor,
						units=1,
						use_bias=False,
						activation=None
						)
			tf.logging.info("****** apply linear transformation for pooled_features *******")
		# batch_size x hidden_dims

		if kargs.get('transform', True):

			if kargs.get("transformer_activation", "none") == 'softplus':
				with tf.variable_scope("transform"):
					ebm_scalar = tf.layers.dense(
							pool_features,
							units=1,
							use_bias=True,
							activation=tf.nn.softplus # maps the scalar to [0, +inf)
							)
				tf.logging.info("****** apply softplus *******")
			elif kargs.get("transformer_activation", "none") == 'linear':
				tf.logging.info("****** apply linear projection *******")
				with tf.variable_scope("transform"):
					ebm_scalar = tf.layers.dense(
							pool_features,
							units=1,
							use_bias=True,
							activation=None # unbounded linear scalar
							)
			else:
				with tf.variable_scope("transform"):

					feature_shape = bert_utils.get_shape_list(pool_features, expected_rank=[1,2])

					pool_features = tf.layers.dense(
							pool_features,
							units=feature_shape[-1],
							activation=tf.nn.relu,
							)

					output_weights = tf.get_variable(
							"output_weights", [config.max_position_embeddings, feature_shape[-1]],
							initializer=tf.truncated_normal_initializer(stddev=0.02))

					output_bias = tf.get_variable(
							"output_bias", [config.max_position_embeddings], 
							initializer=tf.constant_initializer(-np.log(np.arange(config.max_position_embeddings).astype(np.float32)+1.0), dtype=tf.float32)
							)
				
					# batch x max_position_embeddings
					ebm_scalar_pos = tf.nn.relu(tf.matmul(pool_features, output_weights, transpose_b=True)) + output_bias
					
					pos_tensor = tf.cast(tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1), tf.int32)
					onehot_pos = tf.cast(tf.one_hot(tf.cast(pos_tensor, tf.int32), config.max_position_embeddings), tf.float32) # batch x max_position_embeddings
					ebm_scalar = tf.einsum("ab,ab->a", ebm_scalar_pos, onehot_pos)
					ebm_scalar = tf.expand_dims(ebm_scalar, axis=-1)

				tf.logging.info("****** apply linear projection *******")
			print("===ebm_scalar====", ebm_scalar.get_shape())

			ebm_scalar = tf.squeeze(ebm_scalar, axis=-1)
			print("===ebm_scalar====", ebm_scalar.get_shape())
			# ebm_scalar /= (1e-10+tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
			
			# if kargs.get("energy_pooling", "mi") == "mean_pooling":
			
			print("===ebm_scalar====", ebm_scalar.get_shape())
			print("===input_normalized_constant====", input_normalized_constant.get_shape())

		else:
			ebm_scalar = tf.squeeze(pool_features, axis=-1)
			# ebm_scalar /= (1e-10+tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
			print("===ebm_scalar====", ebm_scalar.get_shape())
			print("===input_normalized_constant====", input_normalized_constant.get_shape())

		if not kargs.get("prob_ln", False):
			tf.logging.info("****** sum of plogprob as sentence probability *******")
			# ebm_scalar /= (1e-10+tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
		else:
			ebm_scalar /= (1e-10+tf.reduce_sum(tf.cast(input_mask[:, 1:], tf.float32), axis=-1))
			tf.logging.info("****** sum of plogprob with length normalization as sentence probability *******")
		print("===ebm_scalar====", ebm_scalar.get_shape())
		print("===input_normalized_constant====", input_normalized_constant.get_shape())

		# Original EBM log-likelihood:
		# log(exp(-E(x))/Z) = -E(x) - log(Z)
		# Here the pooled hidden states of the BERT encoder serve as the energy
		# function, so the pooled score has to be negated before it is used as
		# the actual energy.

		if not kargs.get("use_tpu", False):
			tf.summary.scalar('ebm_scalar', 
							tf.reduce_mean(ebm_scalar))

		if kargs.get("logz_mode", "default") == 'default':
			tf.logging.info("****** default logz *******")
			logits = -ebm_scalar - input_normalized_constant - tf.log(1e-10+tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
		elif kargs.get("logz_mode", "default") == 'standard':
			logits = ebm_scalar - input_normalized_constant
			tf.logging.info("****** standard logz *******")
		elif kargs.get("logz_mode", "default") == 'standard_minus':
			tf.logging.info("****** minus standard logz *******")
			logits = -ebm_scalar - input_normalized_constant
		elif kargs.get("logz_mode", "default") == 'constant':
			logits = -ebm_scalar - tf.log(1e-10+tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
			tf.logging.info("****** constant logz *******")
		elif kargs.get("logz_mode", "self_normalizing") == 'self_normalizing':
			logits = -ebm_scalar
			tf.logging.info("****** self_normalizing *******")
		elif kargs.get("logz_mode", "none") == 'none':
			logits = ebm_scalar
			tf.logging.info("****** none logz *******")
		else:
			tf.logging.info("****** linear logz *******")
			logits = ebm_scalar - input_normalized_constant * tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1)
		print("=ebm logits shape==", logits.get_shape())

	return logits
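
A NumPy sketch of the length-conditioned normalizing constant used above: one learnable log-Z entry per possible sequence length is looked up through a one-hot of the valid length and the einsum "ab,b->a", then combined with the pooled energy as in the 'standard_minus' logz_mode (logits = -E(x) - log Z). The sizes and values below are illustrative only.

import numpy as np

max_position_embeddings = 8
vocab_size = 30522
# one log-Z per possible length, initialized to log(vocab) as in 'logv_constant'
normalized_constant = np.full(max_position_embeddings, np.log(vocab_size), dtype=np.float32)

input_mask = np.array([[1, 1, 1, 0, 0, 0, 0, 0],
                       [1, 1, 1, 1, 1, 1, 0, 0]])
valid_seq_length = input_mask.sum(axis=-1)                          # per-example length
onehot_length = np.eye(max_position_embeddings, dtype=np.float32)[valid_seq_length]

# pick the log-Z entry matching each example's length
input_normalized_constant = np.einsum("ab,b->a", onehot_length, normalized_constant)

ebm_scalar = np.array([5.0, 9.0], dtype=np.float32)                 # pooled energies E(x)
logits = -ebm_scalar - input_normalized_constant                    # 'standard_minus' logz mode
print(input_normalized_constant, logits)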