Example #1
0
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator.

        Builds a sequence-tagging model (loss, logits, per-token predictions)
        from the input features, optionally warm-starts weights from
        `init_checkpoint`, and returns a TPUEstimatorSpec for `mode`.

        NOTE(review): relies on closure variables of the enclosing builder
        (bert_config, num_labels, use_one_hot_embeddings, init_checkpoint,
        use_tpu, learning_rate, num_train_steps, num_warmup_steps) — confirm
        they are all defined there.
        """

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        tag_ids = features["tag_ids"]
        # Weight 1.0 for genuine examples; padding examples (TPU batch fill)
        # carry "is_real_example" = 0 so they do not affect the metrics.
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(tag_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        total_loss, logits, predictions = create_model(
            bert_config, is_training, input_ids, input_mask, segment_ids, tag_ids,
            num_labels, use_one_hot_embeddings)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:
                # On TPU the checkpoint restore must happen inside the scaffold.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Log the loss periodically; otherwise nothing is printed on GPU/CPU.
            logging_hook = tf.train.LoggingTensorHook({"loss": total_loss}, every_n_iter=10)
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[logging_hook],
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # BUG FIX: metric_fn previously received total_loss through an
            # unused parameter named `per_example_loss` while silently closing
            # over `predictions`. TPUEstimator requires every tensor a
            # metric_fn reads to be passed explicitly via the args list
            # (closures over graph tensors break on TPU), so `predictions`
            # is now an explicit argument.
            def metric_fn(predictions, tag_ids, is_real_example):
                # Token-level accuracy (a lenient match criterion).
                accuracy = tf.metrics.accuracy(
                    labels=tag_ids, predictions=predictions, weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                }

            eval_metrics = (metric_fn,
                            [predictions, tag_ids, is_real_example])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions={"predictions": predictions},
                scaffold_fn=scaffold_fn)
        return output_spec
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Builds the classification graph, optionally warm-starts from
        `init_checkpoint`, and returns the TPUEstimatorSpec matching `mode`
        (TRAIN, EVAL, or predict). Uses closure variables of the enclosing
        builder: bert_config, num_labels, use_one_hot_embeddings,
        init_checkpoint, use_tpu, learning_rate, num_train_steps and
        num_warmup_steps.
        """

        tf.logging.info("*** Features ***")
        for feature_name in sorted(features):
            tf.logging.info("  name = %s, shape = %s"
                            % (feature_name, features[feature_name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        # Weight 1.0 for genuine examples, 0.0 for TPU batch-padding examples.
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = mode == tf.estimator.ModeKeys.TRAIN

        (total_loss, per_example_loss, logits, probabilities) = create_model(
            bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
            num_labels, use_one_hot_embeddings)

        trainable_vars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            assignment_map, initialized_variable_names = (
                modeling.get_assignment_map_from_checkpoint(
                    trainable_vars, init_checkpoint))
            if not use_tpu:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            else:
                # On TPU the checkpoint restore must happen inside the scaffold.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold

        tf.logging.info("**** Trainable Variables ****")
        for var in trainable_vars:
            ckpt_note = (", *INIT_FROM_CKPT*"
                         if var.name in initialized_variable_names else "")
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            ckpt_note)

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                use_tpu)
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)

        if mode == tf.estimator.ModeKeys.EVAL:
            # All tensors metric_fn reads are passed explicitly via the args
            # list, as TPUEstimator requires.
            def metric_fn(ex_loss, gold_ids, raw_logits, weights):
                preds = tf.argmax(raw_logits, axis=-1, output_type=tf.int32)
                return {
                    "eval_accuracy": tf.metrics.accuracy(
                        labels=gold_ids, predictions=preds, weights=weights),
                    "eval_loss": tf.metrics.mean(
                        values=ex_loss, weights=weights),
                }

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=(metric_fn,
                              [per_example_loss, label_ids, logits,
                               is_real_example]),
                scaffold_fn=scaffold_fn)

        # PREDICT (or any other mode): just emit class probabilities.
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions={"probabilities": probabilities},
            scaffold_fn=scaffold_fn)
def main():
    """Training entry point.

    Builds the graph with placeholder feeds, restores BERT weights from
    `FLAGS.init_checkpoint`, trains for `FLAGS.num_train_epochs` epochs
    (checkpointing and evaluating after each), then runs prediction.
    """
    import os  # for checkpoint-path handling below

    tf.logging.info('start to train')

    # Processor / tokenizer setup.
    process = AllProcessor()
    label_list = process.get_labels()
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    # Convert train/dev examples to serialized features; the helpers return
    # the example counts used to drive the train/eval loops.
    train_examples = process.get_train_examples(FLAGS.data_dir)
    train_cnt = file_based_convert_examples_to_features(
        train_examples,
        label_list,
        FLAGS.max_seq_length,
        tokenizer,
        FLAGS.data_dir,
        'train'
    )
    dev_examples = process.get_dev_examples(FLAGS.data_dir)
    dev_cnt = file_based_convert_examples_to_features(
        dev_examples,
        label_list,
        FLAGS.max_seq_length,
        tokenizer,
        FLAGS.data_dir,
        'dev'
    )

    # Graph inputs.
    input_ids = tf.placeholder(tf.int64, shape=[None, FLAGS.max_seq_length],
                               name='input_ids')
    input_mask = tf.placeholder(tf.int64, shape=[None, FLAGS.max_seq_length],
                                name='input_mask')
    segment_ids = tf.placeholder(tf.int64, shape=[None, FLAGS.max_seq_length],
                                 name='segment_ids')
    labels = tf.placeholder(tf.int64, shape=[None], name='labels')
    task = tf.placeholder(tf.int64, name='task')

    # BERT configuration and model graph.
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    loss, logits, acc, pre_id = create_model(
        bert_config,
        True,
        input_ids,
        input_mask,
        segment_ids,
        labels,
        False,
        task
    )
    # Steps per epoch.
    num_train_steps = int(len(train_examples) / FLAGS.train_batch_size)
    # NOTE(review): this evaluates to ~len(train_examples) * warmup_proportion,
    # not total_steps * warmup_proportion as in the reference BERT recipe —
    # confirm the intended warmup schedule before changing it.
    num_warmup_steps = math.ceil(
        num_train_steps * FLAGS.train_batch_size * FLAGS.warmup_proportion)
    train_op = optimization.create_optimizer(
        loss,
        FLAGS.learning_rate,
        num_train_steps * FLAGS.num_train_epochs,
        num_warmup_steps,
        False
    )

    # Parameter initialization; the saver skips Adam slot variables so the
    # checkpoints stay small.
    init_global = tf.global_variables_initializer()
    saver = tf.train.Saver(
        [v for v in tf.global_variables()
         if 'adam_v' not in v.name and 'adam_m' not in v.name])

    with tf.Session() as sess:
        sess.run(init_global)
        print('start to load bert params')
        if FLAGS.init_checkpoint:
            # tvars = tf.global_variables()
            tvars = tf.trainable_variables()
            print("global_variables", len(tvars))
            assignment_map, initialized_variable_names = \
                modeling.get_assignment_map_from_checkpoint(tvars,
                                                            FLAGS.init_checkpoint)
            print("initialized_variable_names:", len(initialized_variable_names))
            saver_ = tf.train.Saver([v for v in tvars if v.name in initialized_variable_names])
            saver_.restore(sess, FLAGS.init_checkpoint)
            # Re-initialize anything the checkpoint did not cover.
            tvars = tf.global_variables()
            not_initialized_vars = [v for v in tvars if v.name not in initialized_variable_names]
            print('all size %s; not initialized size %s' % (len(tvars), len(not_initialized_vars)))
            if not_initialized_vars:
                sess.run(tf.variables_initializer(not_initialized_vars))
        else:
            print('the bert init checkpoint is None!!!')
            sess.run(tf.global_variables_initializer())

        # One training step.
        def train_step(ids, mask, seg, true_y, task_id):
            feed = {input_ids: ids,
                    input_mask: mask,
                    segment_ids: seg,
                    labels: true_y,
                    task: task_id}
            _, logits_out, loss_out = sess.run([train_op, logits, loss], feed_dict=feed)
            return logits_out, loss_out

        # One evaluation step.
        def dev_step(ids, mask, seg, true_y, task_id):
            feed = {input_ids: ids,
                    input_mask: mask,
                    segment_ids: seg,
                    labels: true_y,
                    task: task_id}
            pre_out, acc_out = sess.run([pre_id, acc], feed_dict=feed)
            return pre_out, acc_out

        # Training loop. int() guards against num_train_epochs being declared
        # as a float flag (the BERT convention), which range() would reject.
        for epoch in range(int(FLAGS.num_train_epochs)):
            tf.logging.info(f'start to train and the epoch:{epoch}')
            epoch_loss = do_train(sess, train_cnt, train_step, epoch)
            tf.logging.info(f'the epoch{epoch} loss is {epoch_loss}')
            # BUG FIX: use os.path.join instead of string concatenation, which
            # silently wrote "<output_dir>bert.ckpt" beside the directory when
            # FLAGS.output_dir lacked a trailing separator.
            saver.save(sess, os.path.join(FLAGS.output_dir, 'bert.ckpt'),
                       global_step=epoch)
            # Validate after every epoch.
            do_eval(sess, dev_cnt, dev_step)

        # Final prediction pass; results are persisted by do_predict.
        do_predict(label_list, process, tokenizer, dev_step)

        tf.logging.info('the training is over!!!!')