def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        if bert_teacher_config is None:
            model = modeling.BertModel(
                config=bert_config,
                is_training=is_training,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=use_one_hot_embeddings,
                use_einsum=use_einsum)

            label_ids = tf.reshape(masked_lm_ids, [-1])
            true_labels = tf.one_hot(label_ids,
                                     depth=bert_config.vocab_size,
                                     dtype=model.get_sequence_output().dtype)
            one_hot_labels = true_labels
        else:
            model = modeling.BertModel(
                config=bert_config,
                is_training=False,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=use_one_hot_embeddings,
                use_einsum=use_einsum)

            with tf.variable_scope("teacher"):
                teacher_model = modeling.BertModel(
                    config=bert_teacher_config,
                    is_training=False,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    token_type_ids=segment_ids,
                    use_one_hot_embeddings=use_one_hot_embeddings,
                    use_einsum=use_einsum)

                label_ids = tf.reshape(masked_lm_ids, [-1])

                true_labels = tf.one_hot(
                    label_ids,
                    depth=bert_config.vocab_size,
                    dtype=model.get_sequence_output().dtype)

                teacher_logits = get_logits(
                    bert_teacher_config,
                    distill_temperature * teacher_model.get_sequence_output(),
                    teacher_model.get_embedding_table(), masked_lm_positions)

                teacher_labels = tf.nn.softmax(teacher_logits, axis=-1)

                if distill_ground_truth_ratio == 1.0:
                    one_hot_labels = true_labels
                else:
                    one_hot_labels = (teacher_labels *
                                      (1 - distill_ground_truth_ratio) +
                                      true_labels * distill_ground_truth_ratio)

                teacher_attentions = teacher_model.get_all_attention_maps()
                student_attentions = model.get_all_attention_maps()

                teacher_hiddens = teacher_model.get_all_encoder_layers()
                student_hiddens = model.get_all_encoder_layers()

        (masked_lm_loss, _, masked_lm_example_loss, masked_lm_log_probs,
         _) = get_masked_lm_output(bert_config, model.get_sequence_output(),
                                   model.get_embedding_table(),
                                   masked_lm_positions,
                                   tf.stop_gradient(one_hot_labels),
                                   true_labels, masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        extra_loss1 = 0.0
        extra_loss2 = 0.0
        extra_loss3 = 0.0
        extra_loss4 = 0.0

        scalars_to_summarize = {}

        def get_layerwise_gate(layer_id):
            steps_per_phase = num_train_steps // bert_config.num_hidden_layers
            layer_wise_gate = distill_util.layer_wise_learning_rate(
                layer_id=layer_id,
                steps_per_phase=steps_per_phase,
                binary=True)
            return layer_wise_gate

        if layer_wise_warmup and hidden_distill_factor != 0.0:
            layer_id = 0
            for teacher_hidden, student_hidden in (zip(teacher_hiddens[1:],
                                                       student_hiddens[1:])):
                with tf.variable_scope("hidden_distill_%d" % layer_id):
                    mse_loss = tf.losses.mean_squared_error(
                        tf.stop_gradient(
                            contrib_layers.layer_norm(inputs=teacher_hidden,
                                                      begin_norm_axis=-1,
                                                      begin_params_axis=-1,
                                                      trainable=False)),
                        contrib_layers.layer_norm(inputs=student_hidden,
                                                  begin_norm_axis=-1,
                                                  begin_params_axis=-1,
                                                  trainable=False))
                    layer_wise_gate = get_layerwise_gate(layer_id)
                    extra_loss1 += layer_wise_gate * mse_loss
                layer_id += 1
            extra_loss1 = extra_loss1 * hidden_distill_factor / layer_id

        if layer_wise_warmup and (beta_distill_factor != 0.0
                                  and gamma_distill_factor != 0.0):
            layer_id = 0
            for teacher_hidden, student_hidden in (zip(teacher_hiddens[1:],
                                                       student_hiddens[1:])):
                with tf.variable_scope("hidden_distill_%d" % layer_id):
                    # Statistics are computed per layer from the loop
                    # variables (not from the full list of hidden states).
                    teacher_mean = tf.reduce_mean(teacher_hidden,
                                                  axis=[-1],
                                                  keepdims=True)
                    student_mean = tf.reduce_mean(student_hidden,
                                                  axis=[-1],
                                                  keepdims=True)
                    teacher_variance = tf.reduce_mean(tf.squared_difference(
                        teacher_hidden, teacher_mean),
                                                      axis=[-1],
                                                      keepdims=True)
                    student_variance = tf.reduce_mean(tf.squared_difference(
                        student_hidden, student_mean),
                                                      axis=[-1],
                                                      keepdims=True)
                    beta_distill_loss = tf.reduce_mean(
                        tf.squared_difference(tf.stop_gradient(teacher_mean),
                                              student_mean))
                    gamma_distill_loss = tf.reduce_mean(
                        tf.abs(
                            tf.stop_gradient(teacher_variance) -
                            student_variance))
                    layer_wise_gate = get_layerwise_gate(layer_id)
                    extra_loss3 += layer_wise_gate * beta_distill_loss
                    extra_loss4 += layer_wise_gate * gamma_distill_loss
                layer_id += 1
            extra_loss3 = extra_loss3 * beta_distill_factor / layer_id
            extra_loss4 = extra_loss4 * gamma_distill_factor / layer_id

        if layer_wise_warmup and attention_distill_factor != 0.0:
            layer_id = 0
            for teacher_attention, student_attention in (zip(
                    teacher_attentions, student_attentions)):
                with tf.variable_scope("attention_distill_%d" % layer_id):
                    teacher_attention_prob = tf.nn.softmax(teacher_attention,
                                                           axis=-1)
                    student_attention_log_prob = tf.nn.log_softmax(
                        student_attention, axis=-1)
                    kl_divergence = -(tf.stop_gradient(teacher_attention_prob)
                                      * student_attention_log_prob)
                    kl_divergence = tf.reduce_mean(
                        tf.reduce_sum(kl_divergence, axis=-1))
                    layer_wise_gate = get_layerwise_gate(layer_id)
                    extra_loss2 += layer_wise_gate * kl_divergence
                layer_id += 1
            extra_loss2 = extra_loss2 * attention_distill_factor / layer_id

        if layer_wise_warmup:
            total_loss = extra_loss1 + extra_loss2 + extra_loss3 + extra_loss4
        else:
            total_loss = masked_lm_loss + next_sentence_loss

        if summary_dir is not None:
            if layer_wise_warmup:
                scalars_to_summarize["feature_map_transfer_loss"] = extra_loss1
                scalars_to_summarize["attention_transfer_loss"] = extra_loss2
                scalars_to_summarize["mean_transfer_loss"] = extra_loss3
                scalars_to_summarize["variance_transfer_loss"] = extra_loss4
            else:
                scalars_to_summarize["masked_lm_loss"] = masked_lm_loss
                scalars_to_summarize["next_sentence_loss"] = next_sentence_loss

                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_accuracy = tf.cast(
                    tf.math.equal(tf.reshape(masked_lm_ids, [-1]),
                                  tf.reshape(masked_lm_predictions, [-1])),
                    tf.float32)
                numerator = tf.reduce_sum(
                    tf.reshape(masked_lm_weights, [-1]) * masked_lm_accuracy)
                denominator = tf.reduce_sum(masked_lm_weights) + 1e-5
                masked_lm_accuracy = numerator / denominator
                scalars_to_summarize["masked_lm_accuracy"] = masked_lm_accuracy

                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_accuracy = tf.reduce_mean(
                    tf.cast(
                        tf.math.equal(
                            tf.reshape(next_sentence_labels, [-1]),
                            tf.reshape(next_sentence_predictions, [-1])),
                        tf.float32))
                scalars_to_summarize[
                    "next_sentence_accuracy"] = next_sentence_accuracy

            scalars_to_summarize[
                "global_step"] = tf.train.get_or_create_global_step()
            scalars_to_summarize["loss"] = total_loss

        host_call = None
        if summary_dir is not None:
            if use_tpu:
                for name in scalars_to_summarize:
                    scalars_to_summarize[name] = tf.reshape(
                        scalars_to_summarize[name], [1])

                def host_call_fn(*args):
                    """Host call function to compute training summaries."""
                    scalars = _list_to_dicts(args,
                                             scalars_to_summarize.keys())[0]
                    for name in scalars:
                        scalars[name] = scalars[name][0]

                    with contrib_summary.create_file_writer(
                            summary_dir, max_queue=1000).as_default():
                        with contrib_summary.always_record_summaries():
                            for name, value in scalars.items():
                                if name not in ["global_step"]:
                                    contrib_summary.scalar(
                                        name,
                                        value,
                                        step=scalars["global_step"])

                    return contrib_summary.all_summary_ops()

                host_call = (host_call_fn,
                             _dicts_to_list([scalars_to_summarize],
                                            scalars_to_summarize.keys()))
            else:
                for name in scalars_to_summarize:
                    tf.summary.scalar(name, scalars_to_summarize[name])

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        teacher_initialized_variable_names = {}
        scaffold_fn = None

        if init_checkpoint:
            if not init_from_teacher:
                # Initializes from the checkpoint for all variables.
                (assignment_map, initialized_variable_names
                 ) = modeling.get_assignment_map_from_checkpoint(
                     tvars, init_checkpoint)
                if use_tpu:

                    def tpu_scaffold():
                        tf.train.init_from_checkpoint(init_checkpoint,
                                                      assignment_map)
                        return tf.train.Scaffold()

                    scaffold_fn = tpu_scaffold
                else:
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
            elif bert_teacher_config is not None:
                # Initializes from the pre-trained checkpoint only for teacher model
                # and embeddings for distillation.
                (assignment_map, initialized_variable_names
                 ) = modeling.get_assignment_map_from_checkpoint(
                     tvars, init_checkpoint, init_embedding=True)
                (teacher_assignment_map, teacher_initialized_variable_names
                 ) = modeling.get_assignment_map_from_checkpoint(
                     tvars, init_checkpoint, init_from_teacher=True)
                if use_tpu:

                    def teacher_tpu_scaffold():
                        tf.train.init_from_checkpoint(init_checkpoint,
                                                      assignment_map)
                        tf.train.init_from_checkpoint(init_checkpoint,
                                                      teacher_assignment_map)
                        return tf.train.Scaffold()

                    scaffold_fn = teacher_tpu_scaffold
                else:
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  teacher_assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        total_size = 0
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            if var.name in teacher_initialized_variable_names:
                init_string = ", *INIT_FROM_TEACHER_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)
            if not var.name.startswith("teacher"):
                total_size += functools.reduce(lambda x, y: x * y,
                                               var.get_shape().as_list())
        tf.logging.info("  total variable parameters: %d", total_size)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            if layer_wise_warmup:
                train_op = optimization.create_optimizer(
                    total_loss,
                    learning_rate,
                    num_train_steps,
                    num_warmup_steps,
                    use_tpu,
                    optimizer,
                    end_lr_rate=1.0,
                    use_layer_wise_warmup=True,
                    total_warmup_phases=bert_config.num_hidden_layers)
            else:
                train_op = optimization.create_optimizer(
                    total_loss, learning_rate, num_train_steps,
                    num_warmup_steps, use_tpu, optimizer)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
                host_call=host_call)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
                             (mode))

        return output_spec
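
A minimal sketch of how a model_fn like the one above is typically handed to TPUEstimator; the RunConfig values and batch sizes below are hypothetical placeholders, not part of the original example:

    run_config = tf.estimator.tpu.RunConfig(
        master=None,  # or the TPU worker address
        model_dir="/tmp/pretraining_output",  # hypothetical output directory
        tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=1000))

    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=False,  # falls back to CPU/GPU when False
        model_fn=model_fn,
        config=run_config,
        train_batch_size=32,
        eval_batch_size=8)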
Example #2
    def model_fn(features, labels, mode, params):
        logging.info("*** Features ***")
        for name in sorted(features.keys()):
            logging.info("  name = %s, shape = %s" % (name, features[name].shape))
        input_ids = features["input_ids"]
        mask = features["mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        # Note: the FLAGS.crf and non-CRF paths invoke create_model with
        # identical arguments; any CRF-specific handling is presumed to live
        # inside create_model.
        (total_loss, logits, predicts) = create_model(
            bert_config, is_training, input_ids, mask, segment_ids, label_ids,
            num_labels, use_one_hot_embeddings)
        tvars = tf.trainable_variables()
        scaffold_fn = None
        # Use an empty dict (not None) so the membership check below is safe
        # when no checkpoint is given.
        initialized_variable_names = {}
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:
            def metric_fn(label_ids, logits, num_labels, mask):
                predictions = tf.math.argmax(logits, axis=-1,
                                             output_type=tf.int32)
                cm = metrics.streaming_confusion_matrix(
                    label_ids, predictions, num_labels - 1, weights=mask)
                return {"confusion_matrix": cm}
            eval_metrics = (metric_fn, [label_ids, logits, num_labels, mask])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predicts, scaffold_fn=scaffold_fn
            )
        return output_spec
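
The eval branch above reports only a streaming confusion matrix. As a small follow-up sketch (not part of the original example), per-class precision, recall, and F1 can be derived from the final matrix with NumPy, assuming rows are true labels and columns are predictions (the tf.confusion_matrix convention):

    import numpy as np

    def scores_from_confusion_matrix(cm):
        # cm: [num_labels, num_labels], rows = true labels, cols = predictions.
        cm = np.asarray(cm, dtype=np.float64)
        tp = np.diag(cm)
        precision = tp / np.maximum(cm.sum(axis=0), 1e-8)
        recall = tp / np.maximum(cm.sum(axis=1), 1e-8)
        f1 = 2 * precision * recall / np.maximum(precision + recall, 1e-8)
        return precision, recall, f1

Example #3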
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        is_real_example, label_ids = None, None
        if FLAGS.export_dir is None:
            label_ids = features["label_ids"]
            if "is_real_example" in features:
                is_real_example = tf.cast(features["is_real_example"],
                                          dtype=tf.float32)
            else:
                is_real_example = tf.ones(tf.shape(label_ids),
                                          dtype=tf.float32)

        is_training = (mode == tf_estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits,
         probabilities) = create_model(bert_config, is_training, input_ids,
                                       input_mask, segment_ids, label_ids,
                                       num_labels, use_one_hot_embeddings)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf_estimator.ModeKeys.TRAIN:

            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu)

            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf_estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [
                per_example_loss, label_ids, logits, is_real_example
            ])
            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            probabilities = tf.identity(probabilities, name="probabilities")
            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions={"probabilities": probabilities},
                scaffold_fn=scaffold_fn)
        return output_spec
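
The PREDICT branch above names the probabilities tensor so the graph can be exported for serving. A minimal serving-input-receiver sketch, assuming a hypothetical fixed max_seq_length that must match the training setting:

    def serving_input_receiver_fn():
        max_seq_length = 128  # hypothetical; must match training
        receiver_tensors = {
            "input_ids": tf.placeholder(tf.int32, [None, max_seq_length],
                                        name="input_ids"),
            "input_mask": tf.placeholder(tf.int32, [None, max_seq_length],
                                         name="input_mask"),
            "segment_ids": tf.placeholder(tf.int32, [None, max_seq_length],
                                          name="segment_ids"),
        }
        return tf_estimator.export.ServingInputReceiver(receiver_tensors,
                                                        receiver_tensors)

    # estimator.export_saved_model(FLAGS.export_dir, serving_input_receiver_fn)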
Example #4
def convert_tf_model(model_dir, save_dir, test_conversion, gpu, mobilebert_dir):
    ctx = mx.gpu(gpu) if gpu is not None else mx.cpu()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    cfg, json_cfg_path, vocab_path = convert_tf_assets(model_dir)
    with open(os.path.join(save_dir, 'model.yml'), 'w') as of:
        of.write(cfg.dump())
    new_vocab = HuggingFaceWordPieceTokenizer(
        vocab_file=vocab_path,
        unk_token='[UNK]',
        pad_token='[PAD]',
        cls_token='[CLS]',
        sep_token='[SEP]',
        mask_token='[MASK]',
        lowercase=True).vocab
    new_vocab.save(os.path.join(save_dir, 'vocab.json'))

    # test input data
    batch_size = 3
    seq_length = 32
    num_mask = 5
    input_ids = np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length))
    valid_length = np.random.randint(seq_length // 2, seq_length, (batch_size,))
    input_mask = np.broadcast_to(np.arange(seq_length).reshape(1, -1), (batch_size, seq_length)) \
        < np.expand_dims(valid_length, 1)
    segment_ids = np.random.randint(0, 2, (batch_size, seq_length))
    mlm_positions = np.random.randint(0, seq_length // 2, (batch_size, num_mask))

    tf_input_ids = tf.constant(input_ids, dtype=np.int32)
    tf_input_mask = tf.constant(input_mask, dtype=np.int32)
    tf_segment_ids = tf.constant(segment_ids, dtype=np.int32)

    init_checkpoint = os.path.join(model_dir, 'mobilebert_variables.ckpt')
    tf_params = read_tf_checkpoint(init_checkpoint)
    # Get the TensorFlow parameter names, with unused (optimizer) variables
    # filtered out.
    tf_names = sorted(tf_params.keys())
    tf_names = filter(lambda name: not name.endswith('adam_m'), tf_names)
    tf_names = filter(lambda name: not name.endswith('adam_v'), tf_names)
    tf_names = filter(lambda name: name != 'global_step', tf_names)
    tf_names = list(tf_names)

    sys.path.append(mobilebert_dir)
    from mobilebert import modeling

    tf_bert_config = modeling.BertConfig.from_json_file(json_cfg_path)
    bert_model = modeling.BertModel(
        config=tf_bert_config,
        is_training=False,
        input_ids=tf_input_ids,
        input_mask=tf_input_mask,
        token_type_ids=tf_segment_ids,
        use_one_hot_embeddings=False)
    tvars = tf.trainable_variables()
    assignment_map, _ = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Variable names end with ':0', e.g.
        # 'MobileBert/embeddings/word_embeddings:0'.
        backbone_params = {v.name.split(":")[0]: v.read_value() for v in tvars}
        backbone_params = sess.run(backbone_params)
        tf_token_outputs_np = {
            'pooled_output': sess.run(bert_model.get_pooled_output()),
            'sequence_output': sess.run(bert_model.get_sequence_output()),
        }

    # The following check only verifies that the backbone parameters match the checkpoint.
    for k in backbone_params:
        assert_allclose(tf_params[k], backbone_params[k])

    # Build gluon model and initialize
    gluon_pretrain_model = MobileBertForPretrain(cfg)
    gluon_pretrain_model.initialize(ctx=ctx)
    gluon_pretrain_model.hybridize()

    # prepare test data
    mx_input_ids = mx.np.array(input_ids, dtype=np.int32, ctx=ctx)
    mx_valid_length = mx.np.array(valid_length, dtype=np.int32, ctx=ctx)
    mx_token_types = mx.np.array(segment_ids, dtype=np.int32, ctx=ctx)
    mx_masked_positions = mx.np.array(mlm_positions, dtype=np.int32, ctx=ctx)

    has_mlm = True
    name_map = get_name_map(tf_names, cfg.MODEL.num_stacked_ffn)
    # go through the gluon model to infer the shape of parameters
    model = gluon_pretrain_model
    contextual_embedding, pooled_output, nsp_score, mlm_scores = \
        model(mx_input_ids, mx_token_types, mx_valid_length, mx_masked_positions)
    # replace tensorflow parameter names with gluon parameter names
    mx_params = model.collect_params()
    all_keys = set(mx_params.keys())
    for (src_name, dst_name) in name_map.items():
        tf_param_val = tf_params[src_name]
        if dst_name is None:
            continue
        all_keys.remove(dst_name)
        if src_name.endswith('kernel'):
            mx_params[dst_name].set_data(tf_param_val.T)
        else:
            mx_params[dst_name].set_data(tf_param_val)

    if has_mlm:
        # 'embedding_table.weight' is shared with word_embed.weight
        all_keys.remove('embedding_table.weight')
    assert len(all_keys) == 0, 'parameters missing from tensorflow checkpoint'

    # test conversion results for backbone model
    if test_conversion:
        tf_contextual_embedding = tf_token_outputs_np['sequence_output']
        tf_pooled_output = tf_token_outputs_np['pooled_output']
        contextual_embedding, pooled_output = model.backbone_model(
            mx_input_ids, mx_token_types, mx_valid_length)
        assert_allclose(pooled_output.asnumpy(), tf_pooled_output, 1E-2, 1E-2)
        for i in range(batch_size):
            ele_valid_length = valid_length[i]
            assert_allclose(contextual_embedding[i, :ele_valid_length, :].asnumpy(),
                            tf_contextual_embedding[i, :ele_valid_length, :], 1E-2, 1E-2)
    model.backbone_model.save_parameters(os.path.join(save_dir, 'model.params'), deduplicate=True)
    logging.info('Converted the backbone model in {} to {}/{}'.format(model_dir, save_dir, 'model.params'))
    model.save_parameters(os.path.join(save_dir, 'model_mlm.params'), deduplicate=True)
    logging.info('Converted the MLM and NSP model in {} to {}/{}'.format(model_dir,
                                                                         save_dir, 'model_mlm.params'))

    logging.info('Conversion finished!')
    logging.info('Statistics:')

    old_names = os.listdir(save_dir)
    for old_name in old_names:
        new_name, long_hash = naming_convention(save_dir, old_name)
        old_path = os.path.join(save_dir, old_name)
        new_path = os.path.join(save_dir, new_name)
        shutil.move(old_path, new_path)
        file_size = os.path.getsize(new_path)
        logging.info('\t{}/{} {} {}'.format(save_dir, new_name, long_hash, file_size))
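
convert_tf_model is typically driven from a small command-line wrapper; a sketch with hypothetical flag names:

    if __name__ == '__main__':
        import argparse

        parser = argparse.ArgumentParser(
            description='Convert a TF MobileBERT checkpoint to Gluon parameters.')
        parser.add_argument('--tf_model_path', required=True,
                            help='Directory holding the TF checkpoint and assets.')
        parser.add_argument('--save_dir', required=True,
                            help='Output directory for the converted model.')
        parser.add_argument('--gpu', type=int, default=None,
                            help='GPU id to verify on, or None for CPU.')
        parser.add_argument('--mobilebert_dir', required=True,
                            help='Path to the mobilebert source tree.')
        parser.add_argument('--test', action='store_true',
                            help='Numerically compare TF and Gluon outputs.')
        args = parser.parse_args()
        convert_tf_model(args.tf_model_path, args.save_dir, args.test,
                         args.gpu, args.mobilebert_dir)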