Example #1
    def bert_module_fn(is_training):
        """Spec function for a token embedding module."""

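        # Dynamic-shape placeholders for the three BERT inputs: token ids,
        # attention mask and segment (token type) ids.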
        input_ids = tf.placeholder(shape=[None, None],
                                   dtype=tf.int32,
                                   name="input_ids")
        input_mask = tf.placeholder(shape=[None, None],
                                    dtype=tf.int32,
                                    name="input_mask")
        token_type = tf.placeholder(shape=[None, None],
                                    dtype=tf.int32,
                                    name="segment_ids")

        config = BertConfig.from_json_file(config_path)
        model = BertModel(config=config,
                          is_training=is_training,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          token_type_ids=token_type)

        model.input_to_output()
        seq_output = model.get_all_encoder_layers()[-1]

        config_file = tf.constant(value=config_path,
                                  dtype=tf.string,
                                  name="config_file")
        vocab_file = tf.constant(value=vocab_path,
                                 dtype=tf.string,
                                 name="vocab_file")
        lower_case = tf.constant(do_lower_case)

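        # Register the config and vocab files as graph assets so they are
        # bundled with the exported module.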
        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, config_file)
        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file)

        input_map = {
            "input_ids": input_ids,
            "input_mask": input_mask,
            "segment_ids": token_type
        }

        output_map = {"sequence_output": seq_output}

        output_info_map = {
            "vocab_file": vocab_file,
            "do_lower_case": lower_case
        }

        hub.add_signature(name="tokens", inputs=input_map, outputs=output_map)
        hub.add_signature(name="tokenization_info",
                          inputs={},
                          outputs=output_info_map)
    def model_fn(features, labels, mode, params):
        tf.logging.info('*** Features ***')
        for name in sorted(features.keys()):
            tf.logging.info(' name = {}, shape = {}'.format(
                name, features[name].shape))

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        input_x = features['input_x']
        input_mask = features['input_mask']

        model = BertModel(config, is_training, input_x, input_mask)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
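        # Warm-start weights from the pre-trained checkpoint when one is provided.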
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        encoder_outputs = model.get_all_encoder_layers()
        last_encoder_outputs = encoder_outputs[-1]
        prev_encoder_outputs = encoder_outputs[:-1]

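        # Teacher head on the final encoder layer, student heads on every earlier layer.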
        last_outputs = teacher_classfier(last_encoder_outputs,
                                         config.num_classes, 'teacher')

        prev_outputs = {}
        for layer_id, layer_out in enumerate(prev_encoder_outputs):
            prev_outputs[layer_id] = student_classifier(
                layer_out, config.num_classes, layer_id)

        if mode == tf.estimator.ModeKeys.PREDICT:
            pass
        else:
            if mode == tf.estimator.ModeKeys.TRAIN:
                last_probs = last_outputs[1]
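                # Distillation loss: KL divergence between the teacher's
                # probabilities and each student layer's probabilities.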
                kl_loss = 0
                for _, value in prev_outputs.items():
                    prev_probs = value[1]
                    kl_loss += kl_divergence(last_probs, prev_probs)

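                # Linearly decay the learning rate to LEARNING_LIMIT over
                # TRAIN_STEPS (polynomial decay with power=1.0), never
                # dropping below the limit.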
                learning_rate = tf.train.polynomial_decay(
                    config.LEARNING_RATE,
                    tf.train.get_or_create_global_step(),
                    config.TRAIN_STEPS,
                    end_learning_rate=config.LEARNING_LIMIT,
                    power=1.0,
                    cycle=False)
                lr = tf.maximum(tf.constant(config.LEARNING_LIMIT),
                                learning_rate)

                optimizer = tf.train.AdamOptimizer(lr, name='optimizer')
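                # Only the non-BERT variables (the classifier heads) receive
                # gradients; the pre-trained encoder stays frozen.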
                updating_tvars = [v for v in tvars if 'bert' not in v.name]
                gradients = tf.gradients(kl_loss,
                                         updating_tvars,
                                         colocate_gradients_with_ops=True)
                clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
                # Pair each clipped gradient with the variable it was computed
                # for (updating_tvars, not the full tvars list).
                train_op = optimizer.apply_gradients(
                    zip(clipped_gradients, updating_tvars),
                    global_step=tf.train.get_global_step())

                logging_hook = tf.train.LoggingTensorHook(
                    {
                        'step': tf.train.get_global_step(),
                        'kl_loss': kl_loss,
                        'lr': learning_rate
                    },
                    every_n_iter=2)
                output_spec = tf.estimator.EstimatorSpec(
                    mode,
                    loss=kl_loss,
                    train_op=train_op,
                    training_hooks=[logging_hook])
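The excerpt stops after building the training `EstimatorSpec`; a complete `model_fn` also has to return a spec for the PREDICT and EVAL branches. The helpers it calls (`teacher_classfier`, `student_classifier`, `kl_divergence`, `get_assignment_map_from_checkpoint`) and the `config`/`init_checkpoint` globals are defined elsewhere in the original module. Below is a minimal, hypothetical sketch of how the KL term could be implemented and how such a `model_fn` would be wired into an Estimator; the input shapes and vocab size are assumptions, not part of the original:

    import tensorflow as tf

    def kl_divergence(teacher_probs, student_probs, epsilon=1e-8):
        # KL(teacher || student) for probability tensors of shape
        # [batch, num_classes], averaged over the batch.
        kl = teacher_probs * (tf.log(teacher_probs + epsilon) -
                              tf.log(student_probs + epsilon))
        return tf.reduce_mean(tf.reduce_sum(kl, axis=-1))

    def train_input_fn():
        # Dummy input_fn for illustration: random token ids (vocab size
        # assumed ~30522) and an all-ones attention mask.
        features = {
            "input_x": tf.random_uniform([8, 128], maxval=30522, dtype=tf.int32),
            "input_mask": tf.ones([8, 128], dtype=tf.int32),
        }
        return features, None

    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="distill_model")
    estimator.train(input_fn=train_input_fn, max_steps=config.TRAIN_STEPS)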