예제 #1
0
파일: generator.py 프로젝트: Beleiaya/BERT
    def model_fn(features, labels, mode, params):
        """Build the generator graph: masked-LM head plus token sampler.

        Relies on names from the enclosing builder scope (``model_config``,
        ``model_io_config``, ``kargs``, ``target``, ``generator_scope_prefix``,
        ``load_pretrained``, ``init_checkpoint``, ``exclude_scope``,
        ``not_storage_params``, ``mask_prior``, ``hmm_tran_prob_list``).

        Args:
            features: dict of input tensors. Reads 'input_ori_ids',
                'input_mask', 'input_ids', 'segment_ids',
                'next_sentence_labels', 'masked_lm_positions',
                'masked_lm_ids' and 'masked_lm_weights'. When the random
                generator runs, 'input_ids' is overwritten in place.
            labels: forwarded to the backbone model.
            mode: a ``tf.estimator.ModeKeys`` value.
            params: unused.

        Returns:
            dict with the generator ``loss``, ``tvars``, the backbone
            ``model``, ``sampled_ids``, the (possibly repeated) original
            inputs, masked-LM and next-sentence outputs, and
            ``sampled_binary_mask`` (``None`` when no on-the-fly masking ran).
        """
        model_api = model_zoo(model_config)

        sampled_binary_mask = None
        # NOTE(review): the original mode list was [TRAIN, TRAIN]; EVAL may
        # have been intended as the second entry — confirm.
        if (kargs.get('random_generator', '1') == '1'
                and mode == tf.estimator.ModeKeys.TRAIN):
            # Re-mask the original ids on the fly with the HMM-based masker
            # (replace/original probabilities are set to 0.0).
            [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'],
                [tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                 for hmm_tran_prob in hmm_tran_prob_list],
                mask_probability=0.2,
                replace_probability=0.0,
                original_probability=0.0,
                mask_prior=tf.constant(mask_prior, tf.float32),
                **kargs)

            features['input_ids'] = output_ids
            tf.logging.info("****** do random generator *******")

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        # The NSP head is built so its variables can be collected into tvars
        # below, but it enters the loss only with weight 0.0.
        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE,
             scope=generator_scope_prefix)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            # NOTE(review): this fallback pairs the BERT masked-LM head with
            # the ALBERT sequence-mask head — confirm that mix is intended.
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            # On-the-fly masking: the LM head gets the sampled mask, with the
            # original ids as targets.
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
            masked_lm_ids = features['input_ori_ids']
        else:
            # Static masking: use the precomputed masked_lm_* features.
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

        if kargs.get("resample_discriminator", False):
            # Draw a fresh random masking and re-run the backbone on it; the
            # token sampler below then reads this second model's outputs.
            [output_ids, sampled_binary_mask] = random_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'],
                mask_probability=0.2,
                replace_probability=0.1,
                original_probability=0.1)

            resample_features = dict(features)
            resample_features['input_ids'] = tf.identity(output_ids)
            model_resample = model_api(model_config,
                                       resample_features,
                                       labels,
                                       mode,
                                       target,
                                       reuse=tf.AUTO_REUSE,
                                       **kargs)

            tf.logging.info("**** apply discriminator resample **** ")
        else:
            model_resample = model
            resample_features = features
            tf.logging.info("**** not apply discriminator resample **** ")

        sampled_ids = token_generator(model_config,
                                      model_resample.get_sequence_output(),
                                      model_resample.get_embedding_table(),
                                      resample_features['input_ids'],
                                      resample_features['input_ori_ids'],
                                      resample_features['input_mask'],
                                      embedding_projection=model_resample.
                                      get_embedding_projection_table(),
                                      scope=generator_scope_prefix,
                                      mask_method='only_mask',
                                      use_tpu=kargs.get('use_tpu', True))

        gen_sample = model_config.get('gen_sample', 1)
        if gen_sample == 1:
            input_ids = features['input_ori_ids']
            input_mask = features['input_mask']
            segment_ids = features['segment_ids']
        else:
            def _repeat_samples(tensor):
                """Copy a [batch, seq] tensor gen_sample times along a new last axis, as int32."""
                # tf.tile keeps the input dtype; an einsum against float32
                # ones fails for integer inputs.
                return tf.cast(
                    tf.tile(tf.expand_dims(tensor, axis=-1), [1, 1, gen_sample]),
                    tf.int32)

            input_ids = _repeat_samples(features['input_ori_ids'])
            input_shape_list = bert_utils.get_shape_list(input_ids,
                                                         expected_rank=3)
            batch_size = input_shape_list[0]
            seq_length = input_shape_list[1]
            flat_shape = [batch_size * gen_sample, seq_length]

            # NOTE(review): reshaping [batch, seq, gen_sample] directly to
            # [batch * gen_sample, seq] (no transpose) does not give whole
            # repeated sequences per row; left as-is because sampled_ids is
            # reshaped the same way — confirm token_generator's output layout.
            sampled_ids = tf.reshape(sampled_ids, flat_shape)
            input_ids = tf.reshape(input_ids, flat_shape)
            input_mask = tf.reshape(
                _repeat_samples(features['input_mask']), flat_shape)
            segment_ids = tf.reshape(
                _repeat_samples(features['segment_ids']), flat_shape)

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        if generator_scope_prefix:
            # e.g. "generator/cls/predictions"
            lm_pretrain_tvars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/predictions",
                not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/seq_relationship",
                not_storage_params=not_storage_params)
        else:
            lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                "cls/seq_relationship", not_storage_params=not_storage_params)

        if model_config.get('embedding_scope', None) is not None:
            embedding_tvars = model_io_fn.get_params(
                model_config.get('embedding_scope', 'bert') + "/embeddings",
                not_storage_params=not_storage_params)
            pretrained_tvars.extend(embedding_tvars)

        pretrained_tvars.extend(lm_pretrain_tvars)
        pretrained_tvars.extend(nsp_pretrain_vars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None
        tf.add_to_collection("generator_loss", loss)
        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "sampled_ids": sampled_ids,  # batch x gen_sample, seg_length
            "sampled_input_ids": input_ids,  # batch x gen_sample, seg_length,
            "sampled_input_mask": input_mask,
            "sampled_segment_ids": segment_ids,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_weights": masked_lm_mask,
            "masked_lm_log_probs": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss,
            "next_sentence_example_loss": nsp_per_example_loss,
            "next_sentence_log_probs": nsp_log_prob,
            "next_sentence_labels": features['next_sentence_labels'],
            "sampled_binary_mask": sampled_binary_mask
        }
        return return_dict
예제 #2
0
	def model_fn(features, labels, mode, params):
		"""Build a masked-LM pretraining TPUEstimatorSpec.

		Relies on names from the enclosing builder scope (``model_config``,
		``model_io_config``, ``opt_config``, ``target``, ``load_pretrained``,
		``init_checkpoint``, ``exclude_scope``, ``not_storage_params``,
		``train_metric_fn``).

		Args:
			features: dict of input tensors; reads 'next_sentence_labels',
				'masked_lm_positions', 'masked_lm_ids' and 'masked_lm_weights'.
			labels: forwarded to the backbone model.
			mode: a ``tf.estimator.ModeKeys`` value.
			params: unused.

		Returns:
			A ``tf.contrib.tpu.TPUEstimatorSpec`` for TRAIN or EVAL.

		Raises:
			NotImplementedError: for any other mode.
		"""
		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE)

		# The NSP head is built so its outputs can be reported as eval
		# metrics; its loss term is commented out of ``loss`` below.
		(nsp_loss,
		 nsp_per_example_loss,
		 nsp_log_prob) = pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE)

		masked_lm_positions = features["masked_lm_positions"]
		masked_lm_ids = features["masked_lm_ids"]
		masked_lm_weights = features["masked_lm_weights"]

		if model_config.model_type == 'bert':
			masked_lm_fn = pretrain.get_masked_lm_output
			print("==apply bert masked lm==")
		elif model_config.model_type == 'albert':
			masked_lm_fn = pretrain_albert.get_masked_lm_output
			print("==apply albert masked lm==")
		else:
			masked_lm_fn = pretrain.get_masked_lm_output
			print("==apply bert masked lm==")

		(masked_lm_loss,
		 masked_lm_example_loss,
		 masked_lm_log_probs,
		 masked_lm_mask) = masked_lm_fn(
										model_config,
										model.get_sequence_output(),
										model.get_embedding_table(),
										masked_lm_positions,
										masked_lm_ids,
										masked_lm_weights,
										reuse=tf.AUTO_REUSE,
										embedding_projection=model.get_embedding_projection_table())
		print(model_config.lm_ratio, '==mlm lm_ratio==')
		loss = model_config.lm_ratio * masked_lm_loss #+ model_config.nsp_ratio * nsp_loss

		model_io_fn = model_io.ModelIO(model_io_config)

		pretrained_tvars = model_io_fn.get_params(model_config.scope,
										not_storage_params=not_storage_params)

		lm_pretrain_tvars = model_io_fn.get_params("cls/predictions",
									not_storage_params=not_storage_params)

		pretrained_tvars.extend(lm_pretrain_tvars)

		if load_pretrained == "yes":
			scaffold_fn = model_io_fn.load_pretrained(pretrained_tvars,
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=1)
		else:
			scaffold_fn = None
		# Indented with tabs like the surrounding lines (mixing in spaces
		# raises TabError under Python 3).
		print("******* scaffold fn *******", scaffold_fn)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			tvars = pretrained_tvars
			model_io_fn.print_params(tvars, string=", trainable params")

			print('==gpu count==', opt_config.get('gpu_count', 1))

			train_op = optimizer_fn.get_train_op(loss, tvars,
							opt_config.init_lr,
							opt_config.num_train_steps,
							use_tpu=opt_config.use_tpu)

			# NOTE(review): the result is not used (its summaries were
			# commented out); kept in case train_metric_fn has graph side
			# effects — confirm before removing.
			train_metric_dict = train_metric_fn(
					masked_lm_example_loss, masked_lm_log_probs,
					masked_lm_ids,
					masked_lm_weights,
					nsp_per_example_loss,
					nsp_log_prob,
					features['next_sentence_labels'],
					masked_lm_mask=masked_lm_mask
				)

			estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
							mode=mode,
							loss=loss,
							train_op=train_op,
							scaffold_fn=scaffold_fn)

			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:

			def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
					masked_lm_weights, next_sentence_example_loss,
					next_sentence_log_probs, next_sentence_labels):
				"""Computes the loss and accuracy of the model."""
				masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
												 [-1, masked_lm_log_probs.shape[-1]])
				masked_lm_predictions = tf.argmax(
					masked_lm_log_probs, axis=-1, output_type=tf.int32)
				masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
				masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
				masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
				masked_lm_accuracy = tf.metrics.accuracy(
					labels=masked_lm_ids,
					predictions=masked_lm_predictions,
					weights=masked_lm_weights)
				masked_lm_mean_loss = tf.metrics.mean(
					values=masked_lm_example_loss, weights=masked_lm_weights)

				next_sentence_log_probs = tf.reshape(
					next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
				next_sentence_predictions = tf.argmax(
					next_sentence_log_probs, axis=-1, output_type=tf.int32)
				next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
				next_sentence_accuracy = tf.metrics.accuracy(
					labels=next_sentence_labels, predictions=next_sentence_predictions)
				next_sentence_mean_loss = tf.metrics.mean(
					values=next_sentence_example_loss)

				return {
					"masked_lm_accuracy": masked_lm_accuracy,
					"masked_lm_loss": masked_lm_mean_loss,
					"next_sentence_accuracy": next_sentence_accuracy,
					"next_sentence_loss": next_sentence_mean_loss
					}

			eval_metrics = (metric_fn, [
			  masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
			  masked_lm_weights, nsp_per_example_loss,
			  nsp_log_prob, features['next_sentence_labels']
			])

			estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
						  mode=mode,
						  loss=loss,
						  eval_metrics=eval_metrics,
						  scaffold_fn=scaffold_fn)

			return estimator_spec
		else:
			raise NotImplementedError()
예제 #3
0
파일: generator.py 프로젝트: CBHell/BERT
    def model_fn(features, labels, mode, params):
        """Build the generator graph under the fixed 'generator' scope.

        Relies on names from the enclosing builder scope (``model_config``,
        ``model_io_config``, ``kargs``, ``target``, ``load_pretrained``,
        ``init_checkpoint``, ``not_storage_params``).

        Args:
            features: dict of input tensors. Reads 'input_ids',
                'input_ori_ids', 'input_mask', 'segment_ids',
                'next_sentence_labels', 'masked_lm_positions',
                'masked_lm_ids' and 'masked_lm_weights'.
            labels: forwarded to the backbone model.
            mode: a ``tf.estimator.ModeKeys`` value.
            params: unused.

        Returns:
            dict with the generator ``loss``, ``tvars``, the backbone
            ``model``, ``sampled_ids``, the (possibly repeated) original
            inputs and the masked-LM outputs.
        """
        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE)

        # The NSP head is built but not used in the loss or the returned
        # dict. NOTE(review): confirm it is still needed.
        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE,
             scope='generator')

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            print("==apply bert masked lm==")

        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table(),
             scope='generator')
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss  #+ model_config.nsp_ratio * nsp_loss

        sampled_ids = token_generator(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            features['input_ids'],
            features['input_ori_ids'],
            features['input_mask'],
            embedding_projection=model.get_embedding_projection_table(),
            scope='generator',
            mask_method='only_mask')

        gen_sample = model_config.get('gen_sample', 1)
        if gen_sample == 1:
            input_ids = features['input_ori_ids']
            input_mask = features['input_mask']
            segment_ids = features['segment_ids']
        else:
            def _repeat_samples(tensor):
                """Copy a [batch, seq] tensor gen_sample times along a new last axis, as int32."""
                # tf.tile keeps the input dtype; an einsum against float32
                # ones fails for integer inputs.
                return tf.cast(
                    tf.tile(tf.expand_dims(tensor, axis=-1), [1, 1, gen_sample]),
                    tf.int32)

            input_ids = _repeat_samples(features['input_ori_ids'])
            input_shape_list = bert_utils.get_shape_list(input_ids,
                                                         expected_rank=3)
            batch_size = input_shape_list[0]
            seq_length = input_shape_list[1]
            flat_shape = [batch_size * gen_sample, seq_length]

            # NOTE(review): reshaping [batch, seq, gen_sample] directly to
            # [batch * gen_sample, seq] (no transpose) does not give whole
            # repeated sequences per row; left as-is because sampled_ids is
            # reshaped the same way — confirm token_generator's output layout.
            sampled_ids = tf.reshape(sampled_ids, flat_shape)
            input_ids = tf.reshape(input_ids, flat_shape)
            input_mask = tf.reshape(
                _repeat_samples(features['input_mask']), flat_shape)
            segment_ids = tf.reshape(
                _repeat_samples(features['segment_ids']), flat_shape)

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "generator/cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope="generator",
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "sampled_ids": sampled_ids,  # batch x gen_sample, seg_length
            "sampled_input_ids": input_ids,  # batch x gen_sample, seg_length,
            "sampled_input_mask": input_mask,
            "sampled_segment_ids": segment_ids,
            "masked_lm_positions": masked_lm_positions,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_weights": masked_lm_weights,
            "masked_lm_log_probs": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss
        }
        return return_dict
예제 #4
0
    def model_fn(features, labels, mode, params):
        """Build the generator graph using the IGR token sampler.

        Relies on names from the enclosing builder scope (``model_config``,
        ``model_io_config``, ``kargs``, ``target``, ``generator_scope_prefix``,
        ``load_pretrained``, ``init_checkpoint``, ``exclude_scope``,
        ``not_storage_params``).

        Args:
            features: dict of input tensors. Reads 'input_ids',
                'input_ori_ids', 'input_mask', 'segment_ids',
                'next_sentence_labels', 'masked_lm_positions',
                'masked_lm_ids' and 'masked_lm_weights'. In TRAIN/EVAL with
                the random generator enabled, 'input_ids' is overwritten.
            labels: forwarded to the backbone model.
            mode: a ``tf.estimator.ModeKeys`` value.
            params: unused.

        Returns:
            dict with the generator ``loss``, ``tvars``, ``model``,
            ``sampled_ids``, the original inputs, masked-LM and
            next-sentence outputs, and ``output_ids`` (the ids fed to the
            model).

        Raises:
            NotImplementedError: if ``model_config.gen_sample`` is not 1.
        """
        model_api = model_zoo(model_config)

        if (kargs.get('random_generator', '1') == '1' and mode in [
                tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL
        ]):
            [output_ids, sampled_binary_mask
             ] = random_input_ids_generation(model_config,
                                             features['input_ori_ids'],
                                             features['input_mask'])
            features['input_ids'] = tf.identity(output_ids)
            tf.logging.info("****** do random generator *******")
        else:
            sampled_binary_mask = None
            output_ids = tf.identity(features['input_ids'])

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        # The NSP head enters the loss only with weight 0.0.
        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE,
             scope=generator_scope_prefix)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            # NOTE(review): this fallback pairs the BERT masked-LM head with
            # the ALBERT sequence-mask head — confirm that mix is intended.
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
            masked_lm_ids = features['input_ori_ids']
        else:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

        sampled_ids = token_generator_igr(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            features['input_ids'],
            features['input_ori_ids'],
            features['input_mask'],
            embedding_projection=model.get_embedding_projection_table(),
            scope=generator_scope_prefix,
            mask_method='only_mask',
            **kargs)

        if model_config.get('gen_sample', 1) != 1:
            # Only single-sample generation is implemented here; any other
            # value used to leave input_ids/input_mask/segment_ids unbound.
            raise NotImplementedError(
                "gen_sample != 1 is not supported by this generator model_fn")
        input_ids = features['input_ori_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        if generator_scope_prefix:
            # e.g. "generator/cls/predictions"
            lm_pretrain_tvars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/predictions",
                not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/seq_relationship",
                not_storage_params=not_storage_params)
        else:
            lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                "cls/seq_relationship", not_storage_params=not_storage_params)

        # NOTE(review): nsp_pretrain_vars is collected but never added to
        # tvars, so the NSP head is neither restored nor returned — confirm.
        pretrained_tvars.extend(lm_pretrain_tvars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "sampled_ids": sampled_ids,  # batch x gen_sample, seg_length
            "sampled_input_ids": input_ids,  # batch x gen_sample, seg_length,
            "sampled_input_mask": input_mask,
            "sampled_segment_ids": segment_ids,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_weights": masked_lm_mask,
            "masked_lm_log_probs": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss,
            "next_sentence_example_loss": nsp_per_example_loss,
            "next_sentence_log_probs": nsp_log_prob,
            "next_sentence_labels": features['next_sentence_labels'],
            "output_ids": output_ids
        }
        return return_dict
예제 #5
0
    def model_fn(features, labels, mode, params):
        """Build a sequence-masked LM head over the whole input and return its loss.

        Relies on names from the enclosing builder scope (``model_config``,
        ``model_io_config``, ``kargs``, ``target``, ``load_pretrained``,
        ``init_checkpoint``, ``exclude_scope``, ``not_storage_params``).

        Args:
            features: dict of input tensors; reads 'input_mask',
                'input_ori_ids', 'input_ids' and 'next_sentence_labels'.
            labels: forwarded to the backbone model.
            mode: a ``tf.estimator.ModeKeys`` value.
            params: unused.

        Returns:
            dict with ``loss``, ``tvars``, ``model``, per-example LM loss,
            the LM mask/log-probs and the next-sentence outputs. The loss is
            also added to the "discriminator_loss" collection.
        """
        model_api = model_zoo(model_config)
        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        # The NSP head enters the loss only with weight 0.0.
        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE)

        if model_config.model_type == 'albert':
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            # 'bert' and any other model type use the BERT head.
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        # input_mask is passed where a sampled binary mask is expected, so
        # presumably every non-padding position is an LM target — confirm.
        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
         masked_lm_mask) = seq_masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             features['input_mask'],
             features['input_ori_ids'],
             features['input_ids'],
             features['input_mask'],
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table())

        loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

        model_io_fn = model_io.ModelIO(model_io_config)
        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        nsp_pretrain_vars = model_io_fn.get_params(
            "cls/seq_relationship", not_storage_params=not_storage_params)

        if model_config.get('embedding_scope', None) is not None:
            embedding_tvars = model_io_fn.get_params(
                model_config.get('embedding_scope', 'bert') + "/embeddings",
                not_storage_params=not_storage_params)
            pretrained_tvars.extend(embedding_tvars)

        pretrained_tvars.extend(lm_pretrain_tvars)
        pretrained_tvars.extend(nsp_pretrain_vars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None
        tf.add_to_collection("discriminator_loss", loss)
        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "per_example_loss": masked_lm_example_loss,
            "masked_lm_weights": masked_lm_mask,
            "masked_lm_log_probs": masked_lm_log_probs,
            "next_sentence_example_loss": nsp_per_example_loss,
            "next_sentence_log_probs": nsp_log_prob,
            "next_sentence_labels": features['next_sentence_labels']
        }
        return return_dict
예제 #6
0
	def model_fn(features, labels, mode, params):
		"""Build a per-token 2-class classifier and export it for prediction.

		Relies on names from the enclosing builder scope (``model_config``,
		``model_io_config``, ``kargs``, ``target``, ``classifier``,
		``load_pretrained``, ``init_checkpoint``, ``exclude_scope``,
		``not_storage_params``).

		Args:
			features: dict of input tensors; reads 'next_sentence_labels',
				'input_ori_ids', 'input_ids' and 'input_mask'.
			labels: forwarded to the backbone model.
			mode: a ``tf.estimator.ModeKeys`` value.
			params: unused.

		Returns:
			A ``tf.estimator.EstimatorSpec`` whose "probs" prediction is the
			softmax of the classifier logits, zeroed where input_mask is 0.

		Raises:
			NotImplementedError: for any mode other than PREDICT (this used
				to fall through and return None).
		"""
		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE,
							**kargs)

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		# Outputs are unused; the call presumably creates the
		# cls/seq_relationship variables collected below.
		pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE)

		with tf.variable_scope('cls/seq_predictions', reuse=tf.AUTO_REUSE):
			(_, logits, _) = classifier(model_config,
										model.get_sequence_output(),
										features['input_ori_ids'],
										features['input_ids'],
										features['input_mask'],
										2,
										dropout_prob)

		model_io_fn = model_io.ModelIO(model_io_config)

		pretrained_tvars = model_io_fn.get_params(model_config.scope,
										not_storage_params=not_storage_params)
		seq_prediction_tvars = model_io_fn.get_params("cls/seq_predictions",
									not_storage_params=not_storage_params)
		nsp_tvars = model_io_fn.get_params("cls/seq_relationship",
									not_storage_params=not_storage_params)

		pretrained_tvars.extend(seq_prediction_tvars)
		pretrained_tvars.extend(nsp_tvars)
		tvars = pretrained_tvars

		print('==discriminator parameters==', tvars)

		if load_pretrained == "yes":
			use_tpu = 1 if kargs.get('use_tpu', False) else 0
			scaffold_fn = model_io_fn.load_pretrained(tvars,
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=use_tpu)
		else:
			scaffold_fn = None

		if mode != tf.estimator.ModeKeys.PREDICT:
			raise NotImplementedError(
				"this model_fn only supports tf.estimator.ModeKeys.PREDICT")

		# Expand input_mask by one trailing axis so it broadcasts over the class axis.
		mask = tf.cast(tf.expand_dims(features['input_mask'], axis=-1), tf.float32)
		probs = tf.nn.softmax(logits) * mask
		estimator_spec = tf.estimator.EstimatorSpec(
			mode=mode,
			predictions={"probs": probs},
			export_outputs={
				"output": tf.estimator.export.PredictOutput({"probs": probs})
			})
		return estimator_spec
예제 #7
0
def next_sentence_prediction(model_config, model, features, reuse=None):
	"""Build the next-sentence-prediction (NSP) head on top of `model`.

	Args:
		model_config: model configuration forwarded to
			`pretrain.get_next_sentence_output`.
		model: backbone exposing `get_pooled_output()`.
		features: dict of input tensors holding the NSP labels under
			"next_sentence_label", or (fallback) "next_sentence_labels",
			the key the other model_fns in this file read.
		reuse: variable-reuse flag forwarded to the NSP head.

	Returns:
		The `(next_sentence_loss, next_sentence_example_loss,
		next_sentence_log_probs)` triple produced by
		`pretrain.get_next_sentence_output`.

	Raises:
		KeyError: if neither label key is present in `features`.
	"""
	if "next_sentence_label" in features:
		next_sentence_labels = features["next_sentence_label"]
	else:
		next_sentence_labels = features["next_sentence_labels"]

	# Forward the caller's `reuse` flag instead of forcing a fresh scope.
	(next_sentence_loss, next_sentence_example_loss,
	 next_sentence_log_probs) = pretrain.get_next_sentence_output(
		model_config, model.get_pooled_output(), next_sentence_labels,
		reuse=reuse)

	return (next_sentence_loss, next_sentence_example_loss,
			next_sentence_log_probs)
    def model_fn(features, labels, mode):
        """Joint masked-LM + next-sentence-prediction pre-training model_fn.

        Builds the backbone returned by ``model_zoo(model_config)``, attaches
        the NSP head and the masked-LM head, and optimizes
        ``lm_ratio * masked_lm_loss + nsp_ratio * nsp_loss``.

        Args:
            features: dict of input tensors. This function reads
                'next_sentence_labels', 'masked_lm_positions',
                'masked_lm_ids' and 'masked_lm_weights'; the backbone may
                read more.
            labels: forwarded unchanged to the backbone constructor.
            mode: a ``tf.estimator.ModeKeys`` value.

        Returns:
            TRAIN: an ``EstimatorSpec`` when ``output_type == "estimator"``,
                or ``{"train": {...}, "hooks": [...]}`` when
                ``output_type == "sess"``.
            PREDICT: an ``EstimatorSpec`` with NSP and masked-LM predictions.
            EVAL: an ``EstimatorSpec`` with metric ops (``"estimator"``) or
                ``{"eval": {...}}`` (``"sess"``).
            For TRAIN/EVAL with any other ``output_type`` nothing is
            returned (implicit ``None``).

        Raises:
            NotImplementedError: for any mode other than TRAIN/PREDICT/EVAL.
        """
        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        # Variable scope handed to the parameter-EMA hooks below.
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=model_reuse)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = pretrain.get_masked_lm_output(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=model_reuse)
        loss = (model_config.lm_ratio * masked_lm_loss +
                model_config.nsp_ratio * nsp_loss)

        model_io_fn = model_io.ModelIO(model_io_config)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Trainable set: backbone variables plus every "cls" head
            # (masked-LM and NSP).
            pretrained_tvars = model_io_fn.get_params(
                model_config.scope, not_storage_params=not_storage_params)

            lm_pretrain_tvars = model_io_fn.get_params(
                "cls", not_storage_params=not_storage_params)

            pretrained_tvars.extend(lm_pretrain_tvars)

            optimizer_fn = optimizer.Optimizer(opt_config)

            # NOTE(review): truthiness test, whereas sibling model_fns compare
            # `load_pretrained == "yes"` (a string such as "no" would load
            # here) -- confirm what callers pass.
            if load_pretrained:
                model_io_fn.load_pretrained(pretrained_tvars,
                                            init_checkpoint,
                                            exclude_scope=exclude_scope)

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            # Run batch-norm style update ops before each optimizer step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):

                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                # NOTE(review): the returned `hooks` are not added to
                # `training_hooks` below -- confirm that is intended.
                train_op, hooks = model_io_fn.get_ema_hooks(
                    train_op,
                    tvars,
                    kargs.get('params_moving_average_decay', 0.99),
                    scope,
                    mode,
                    first_stage_steps=opt_config.num_warmup_steps,
                    two_stage=True)

                model_io_fn.set_saver()

                train_metric_dict = train_metric_fn(
                    masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                    masked_lm_weights, nsp_per_example_loss, nsp_log_prob,
                    features['next_sentence_labels'])

                for key in train_metric_dict:
                    tf.summary.scalar(key, train_metric_dict[key])
                tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

                # Only the chief (task_index 0) without a run_config installs
                # the explicit checkpoint hook.
                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "nsp_log_pro": nsp_log_prob,
                            "train_op": train_op,
                            "masked_lm_loss": masked_lm_loss,
                            "next_sentence_loss": nsp_loss,
                            "masked_lm_log_pro": masked_lm_log_probs
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            # Predictions come straight from the NSP and masked-LM
            # log-probabilities; this model_fn has no separate `logits`
            # tensor (referencing one raised NameError).
            predictions = {
                "nsp_classes":
                tf.argmax(input=nsp_log_prob, axis=1),
                "nsp_probabilities":
                tf.exp(nsp_log_prob, name="nsp_softmax"),
                "masked_vocab_classes":
                tf.argmax(input=masked_lm_log_probs, axis=1),
                "masked_probabilities":
                tf.exp(masked_lm_log_probs, name='masked_softmax')
            }

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs={
                    "output": tf.estimator.export.PredictOutput(predictions)
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss
                }

            if output_type == "sess":
                return {
                    "eval": {
                        "nsp_log_prob": nsp_log_prob,
                        "masked_lm_log_prob": masked_lm_log_probs,
                        "nsp_loss": nsp_loss,
                        "masked_lm_loss": masked_lm_loss,
                        "feature": model.get_pooled_output()
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(masked_lm_example_loss,
                                            masked_lm_log_probs, masked_lm_ids,
                                            masked_lm_weights,
                                            nsp_per_example_loss, nsp_log_prob,
                                            features['next_sentence_labels'])
                _, hooks = model_io_fn.get_ema_hooks(
                    None, None, kargs.get('params_moving_average_decay', 0.99),
                    scope, mode)

                eval_hooks = [hooks] if hooks else []

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metric_ops=eval_metric_ops,
                    evaluation_hooks=eval_hooks)
                return estimator_spec
        else:
            raise NotImplementedError()
예제 #9
0
	def model_fn(features, labels, mode, params):
		"""Build the discriminator graph and return its tensors.

		Runs the backbone from `model_zoo(model_config)`, builds the NSP head
		(its loss is added with weight 0.0) and a 2-class `classifier` over
		the sequence output under 'cls/seq_predictions', then optionally
		restores pretrained weights.

		Args:
			features: dict of input tensors; reads 'next_sentence_labels',
				'input_ori_ids', 'ori_input_ids', 'input_mask' and, when
				present, 'ori_sampled_ids'.
			labels: forwarded unchanged to the backbone constructor.
			mode: a `tf.estimator.ModeKeys` value; dropout is only applied
				in TRAIN.
			params: not referenced in this body.

		Returns:
			dict with keys "loss", "logits", "tvars", "model",
			"per_example_loss" and "scaffold_fn". "scaffold_fn" is the value
			returned by `model_io_fn.load_pretrained`, or None when
			`load_pretrained != "yes"`; it is exposed so a caller can hand it
			to a TPU estimator spec.
		"""
		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE,
							**kargs)

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		(nsp_loss,
		 _,
		 _) = pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE)

		# NOTE(review): the classifier defaults `use_tpu` to True while the
		# checkpoint restore below defaults it to False -- confirm.
		with tf.variable_scope('cls/seq_predictions', reuse=tf.AUTO_REUSE):
			(loss,
			 logits,
			 per_example_loss) = classifier(model_config,
									model.get_sequence_output(),
									features['input_ori_ids'],
									features['ori_input_ids'],
									features['input_mask'],
									2,
									dropout_prob,
									ori_sampled_ids=features.get('ori_sampled_ids', None),
									use_tpu=kargs.get('use_tpu', True))

		# Publish the pure discriminator loss before the NSP term is added.
		tf.add_to_collection("discriminator_loss", loss)
		# Zero-weighted: does not change the loss value; presumably keeps the
		# NSP head attached to the graph's loss.
		loss += 0.0 * nsp_loss

		model_io_fn = model_io.ModelIO(model_io_config)

		# Backbone + discriminator head + NSP head variables.
		pretrained_tvars = model_io_fn.get_params(model_config.scope,
										not_storage_params=not_storage_params)
		lm_seq_prediction_tvars = model_io_fn.get_params("cls/seq_predictions",
									not_storage_params=not_storage_params)
		lm_pretrain_tvars = model_io_fn.get_params("cls/seq_relationship",
									not_storage_params=not_storage_params)

		pretrained_tvars.extend(lm_seq_prediction_tvars)
		pretrained_tvars.extend(lm_pretrain_tvars)
		tvars = pretrained_tvars

		print('==discriminator parameters==', tvars)

		if load_pretrained == "yes":
			use_tpu = 1 if kargs.get('use_tpu', False) else 0
			scaffold_fn = model_io_fn.load_pretrained(tvars,
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=use_tpu,
											restore_var_name=model_config.get('restore_var_name', []))
		else:
			scaffold_fn = None

		return {
			"loss": loss,
			"logits": logits,
			"tvars": tvars,
			"model": model,
			"per_example_loss": per_example_loss,
			"scaffold_fn": scaffold_fn,
		}
예제 #10
0
	def model_fn(features, labels, mode, params):
		"""Build the masked-LM generator graph and serve its token distributions.

		Runs the backbone from `model_zoo(model_config)` plus the NSP and
		sequence-masked-LM heads (both under `generator_scope_prefix`),
		collects the generator's variables and optionally restores them from
		`init_checkpoint`.

		Args:
			features: dict of input tensors; reads 'next_sentence_labels',
				'input_mask', 'input_ori_ids' and 'input_ids'.
			labels: forwarded unchanged to the backbone constructor.
			mode: a `tf.estimator.ModeKeys` value.
			params: not referenced in this body.

		Returns:
			In PREDICT mode, an `EstimatorSpec` whose "probs" output is
			softmax(masked_lm_log_probs) multiplied by `input_mask`
			(presumably 0 on padding). In every other mode nothing is
			returned (implicit None), although the graph is still built.
		"""
		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE,
							**kargs)

		# Only builds the NSP head variables; its outputs are not used here.
		pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE,
										scope=generator_scope_prefix)

		# Select the sequence-masked-LM head implementation for the backbone.
		if model_config.model_type == 'bert':
			seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
			print("==apply bert masked lm==")
		elif model_config.model_type == 'albert':
			seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
			print("==apply albert masked lm==")
		else:
			# NOTE(review): this fallback logs "bert" but uses the ALBERT head,
			# unlike the 'bert' branch above -- confirm it is intended.
			seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
			print("==apply bert masked lm==")

		# `input_mask` is passed as the prediction mask as well, so the head
		# scores every non-padded position rather than a sampled subset.
		(_,
		 _,
		 masked_lm_log_probs,
		 _) = seq_masked_lm_fn(model_config,
										model.get_sequence_output(),
										model.get_embedding_table(),
										features['input_mask'],
										features['input_ori_ids'],
										features['input_ids'],
										features['input_mask'],
										reuse=tf.AUTO_REUSE,
										embedding_projection=model.get_embedding_projection_table(),
										scope=generator_scope_prefix)

		print(model_config.lm_ratio, '==mlm lm_ratio==')

		model_io_fn = model_io.ModelIO(model_io_config)

		pretrained_tvars = model_io_fn.get_params(model_config.scope,
										not_storage_params=not_storage_params)

		# Head variables live under the generator prefix when one is set,
		# e.g. "generator/cls/predictions".
		if generator_scope_prefix:
			lm_scope = generator_scope_prefix + "/cls/predictions"
			nsp_scope = generator_scope_prefix + "/cls/seq_relationship"
		else:
			lm_scope = "cls/predictions"
			nsp_scope = "cls/seq_relationship"

		lm_pretrain_tvars = model_io_fn.get_params(lm_scope,
										not_storage_params=not_storage_params)
		nsp_pretrain_vars = model_io_fn.get_params(nsp_scope,
										not_storage_params=not_storage_params)

		pretrained_tvars.extend(lm_pretrain_tvars)
		pretrained_tvars.extend(nsp_pretrain_vars)
		tvars = pretrained_tvars

		print('==generator parameters==', tvars)

		if load_pretrained == "yes":
			use_tpu = 1 if kargs.get('use_tpu', False) else 0
			scaffold_fn = model_io_fn.load_pretrained(tvars,
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=use_tpu)
		else:
			scaffold_fn = None

		if mode == tf.estimator.ModeKeys.PREDICT:
			mask = tf.expand_dims(tf.cast(features['input_mask'], tf.float32), axis=-1)
			# exp(log_softmax(x)) == softmax(x); built once and shared by the
			# prediction dict and the export signature.
			probs = mask * tf.exp(tf.nn.log_softmax(masked_lm_log_probs))
			estimator_spec = tf.estimator.EstimatorSpec(
				mode=mode,
				predictions={"probs": probs},
				export_outputs={
					"output": tf.estimator.export.PredictOutput({"probs": probs})
				})
			return estimator_spec
예제 #11
0
	def model_fn(features, labels, mode, params):
		"""TPU model_fn for masked-LM pre-training (NSP built for metrics only).

		In TRAIN mode, when `features` carries 'input_ori_ids', the input is
		re-masked on the fly with `hmm_input_ids_generation` and the masked-LM
		loss is computed over the resulting `sampled_binary_mask` with
		`seq_masked_lm_fn`. Otherwise the pre-computed 'masked_lm_positions',
		'masked_lm_ids' and 'masked_lm_weights' are used with `masked_lm_fn`.
		The optimized loss is `lm_ratio * masked_lm_loss`; the NSP head only
		feeds the training/eval metrics.

		Args:
			features: dict of input tensors. Reads 'input_ori_ids' (optional),
				'input_mask', 'next_sentence_labels', 'masked_lm_positions',
				'masked_lm_ids', 'masked_lm_weights'; on the re-masking path it
				overwrites 'input_ids' in place.
			labels: forwarded unchanged to the backbone constructor.
			mode: a `tf.estimator.ModeKeys` value.
			params: not referenced in this body.

		Returns:
			A `tf.contrib.tpu.TPUEstimatorSpec` for TRAIN or EVAL.

		Raises:
			NotImplementedError: for any other mode (including PREDICT).
		"""

		model_api = model_zoo(model_config)

		input_ori_ids = features.get('input_ori_ids', None)
		# Re-masking happens only in TRAIN and only when the original ids are
		# available; every other path leaves sampled_binary_mask as None.
		if mode == tf.estimator.ModeKeys.TRAIN:
			if input_ori_ids is not None:
				# [output_ids, 
				# sampled_binary_mask] = random_input_ids_generation(
				# 							model_config, 
				# 							input_ori_ids,
				# 							features['input_mask'],
				# 							**kargs)

				[output_ids, 
				sampled_binary_mask] = hmm_input_ids_generation(model_config,
											features['input_ori_ids'],
											features['input_mask'],
											[tf.cast(tf.constant(hmm_tran_prob), tf.float32) for hmm_tran_prob in hmm_tran_prob_list],
											mask_probability=0.2,
											replace_probability=0.1,
											original_probability=0.1,
											mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
											**kargs)

				# Must happen before the backbone is built so it consumes the
				# freshly masked ids.
				features['input_ids'] = output_ids
				tf.logging.info("***** Running random sample input generation *****")
			else:
				sampled_binary_mask = None
		else:
			sampled_binary_mask = None

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE,
							**kargs)

		# NOTE(review): `dropout_prob` and `scope` are computed but not used
		# anywhere below in this function.
		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope
		
		# NSP head: its loss is not optimized (see `loss` below); its outputs
		# only feed train_metric_fn / eval metric_fn.
		(nsp_loss, 
		 nsp_per_example_loss, 
		 nsp_log_prob) = pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE)

		masked_lm_positions = features["masked_lm_positions"]
		masked_lm_ids = features["masked_lm_ids"]
		masked_lm_weights = features["masked_lm_weights"]

		# Pick the masked-LM head implementations for the backbone type.
		if model_config.model_type == 'bert':
			masked_lm_fn = pretrain.get_masked_lm_output
			seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
			print("==apply bert masked lm==")
		elif model_config.model_type == 'albert':
			masked_lm_fn = pretrain_albert.get_masked_lm_output
			seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
			print("==apply albert masked lm==")
		else:
			# NOTE(review): this fallback logs "bert" and uses the BERT
			# masked_lm_fn, but the ALBERT seq_masked_lm_fn -- confirm.
			masked_lm_fn = pretrain.get_masked_lm_output
			seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
			print("==apply bert masked lm==")

		if sampled_binary_mask is not None:
			# Re-masked path: predict over the sampled positions of the full
			# sequence; the labels become the original ids, with masked_lm_mask
			# used as the weights in the metrics below.
			(masked_lm_loss,
			masked_lm_example_loss, 
			masked_lm_log_probs,
			masked_lm_mask) = seq_masked_lm_fn(model_config, 
										model.get_sequence_output(), 
										model.get_embedding_table(),
										features['input_mask'], 
										features['input_ori_ids'], 
										features['input_ids'],
										sampled_binary_mask,
										reuse=tf.AUTO_REUSE,
										embedding_projection=model.get_embedding_projection_table())
			masked_lm_ids = input_ori_ids
		else:
			# Pre-computed path: gather predictions at masked_lm_positions.
			(masked_lm_loss,
			masked_lm_example_loss, 
			masked_lm_log_probs,
			masked_lm_mask) = masked_lm_fn(
											model_config, 
											model.get_sequence_output(), 
											model.get_embedding_table(),
											masked_lm_positions, 
											masked_lm_ids, 
											masked_lm_weights,
											reuse=tf.AUTO_REUSE,
											embedding_projection=model.get_embedding_projection_table())
		print(model_config.lm_ratio, '==mlm lm_ratio==')
		loss = model_config.lm_ratio * masked_lm_loss #+ 0.0 * nsp_loss
		
		model_io_fn = model_io.ModelIO(model_io_config)

		# Trainable/restorable set: backbone + masked-LM head.
		# NOTE(review): "cls/seq_relationship" (NSP head) variables are not
		# included, so they are neither passed to load_pretrained nor trained;
		# NSP metrics may reflect an untrained head -- confirm.
		pretrained_tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)

		lm_pretrain_tvars = model_io_fn.get_params("cls/predictions", 
									not_storage_params=not_storage_params)

		pretrained_tvars.extend(lm_pretrain_tvars)

		# NOTE(review): use_tpu is hard-coded to 1 here, whereas sibling
		# model_fns derive it from kargs['use_tpu'].
		if load_pretrained == "yes":
			scaffold_fn = model_io_fn.load_pretrained(pretrained_tvars, 
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=1)
		else:
			scaffold_fn = None

		if mode == tf.estimator.ModeKeys.TRAIN:
						
			optimizer_fn = optimizer.Optimizer(opt_config)
						
			tvars = pretrained_tvars
			model_io_fn.print_params(tvars, string=", trainable params")
			
			# Run update ops collected so far before each optimizer step.
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(loss, tvars,
								opt_config.init_lr, 
								opt_config.num_train_steps,
								use_tpu=opt_config.use_tpu)

				# NOTE(review): train_metric_dict is unused while the summary
				# lines below stay commented out.
				train_metric_dict = train_metric_fn(
						masked_lm_example_loss, masked_lm_log_probs, 
						masked_lm_ids,
						masked_lm_mask, 
						nsp_per_example_loss,
						nsp_log_prob, 
						features['next_sentence_labels'],
						masked_lm_mask=masked_lm_mask
					)

				# for key in train_metric_dict:
				# 	tf.summary.scalar(key, train_metric_dict[key])
				# tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

				estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
								mode=mode,
								loss=loss,
								train_op=train_op,
								scaffold_fn=scaffold_fn)

				return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:

			def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
					masked_lm_weights, next_sentence_example_loss,
					next_sentence_log_probs, next_sentence_labels):
				"""Computes the loss and accuracy of the model."""
				masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
												 [-1, masked_lm_log_probs.shape[-1]])
				masked_lm_predictions = tf.argmax(
					masked_lm_log_probs, axis=-1, output_type=tf.int32)
				masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
				masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
				masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
				masked_lm_accuracy = tf.metrics.accuracy(
					labels=masked_lm_ids,
					predictions=masked_lm_predictions,
					weights=masked_lm_weights)
				masked_lm_mean_loss = tf.metrics.mean(
					values=masked_lm_example_loss, weights=masked_lm_weights)

				next_sentence_log_probs = tf.reshape(
					next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
				next_sentence_predictions = tf.argmax(
					next_sentence_log_probs, axis=-1, output_type=tf.int32)
				next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
				next_sentence_accuracy = tf.metrics.accuracy(
					labels=next_sentence_labels, predictions=next_sentence_predictions)
				next_sentence_mean_loss = tf.metrics.mean(
					values=next_sentence_example_loss)

				return {
					"masked_lm_accuracy": masked_lm_accuracy,
					"masked_lm_loss": masked_lm_mean_loss,
					"next_sentence_accuracy": next_sentence_accuracy,
					"next_sentence_loss": next_sentence_mean_loss
					}

			# TPUEstimatorSpec takes (metric_fn, tensors) and calls metric_fn
			# with the tensors; masked_lm_mask is passed as the weights.
			eval_metrics = (metric_fn, [
			  masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
			  masked_lm_mask, nsp_per_example_loss,
			  nsp_log_prob, features['next_sentence_labels']
			])

			estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
						  mode=mode,
						  loss=loss,
						  eval_metrics=eval_metrics,
						  scaffold_fn=scaffold_fn)

			return estimator_spec
		else:
			raise NotImplementedError()