Code example #1
def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_eval:
        raise ValueError("At least 'do_eval' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.eval_batch_size
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Running single-head mask-out and direct evaluation!")

    # we mask out each individual head and then evaluate, so there are 12 layers * 12 heads results.
    n_layers = 12
    n_heads = 12
    folder = FLAGS.output_dir
    save_file = 'single_head_mask.pickle'
    output = np.zeros((n_layers, n_heads))

    # two placeholders for the masked-head coordinates: layer index and head index
    head_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,  # input_ids,
        input_mask=input_mask_ph,  # input_mask,
        token_type_ids=segment_ids_ph,  # segment_ids,
        use_one_hot_embeddings=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    output_layer = model.get_pooled_output()
    output_weights = tf.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb (regression): cast the int32 label placeholder to float
        # so the squared-error loss builds with matching dtypes
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits -
                                         tf.cast(label_ids_ph, tf.float32))
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric is tf.metric object, (val, op)
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.train.Saver(tvars)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.variables_initializer(var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        # if the number of eval examples is at most 1000, load them all at once; otherwise load by batch.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
            eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=num_actual_eval_examples,
                                                                     seq_length=seq_length,
                                                                     examples=eval_examples,
                                                                     label_list=label_list,
                                                                     tokenizer=tokenizer)

        # loop over layers, then loop over heads
        for l in range(n_layers):
            for h in range(n_heads):

                cur_l, cur_h = l, h
                head_mask = [h]
                layer_mask = [l]

                # if the number of eval examples is at most 1000, load them all at once; otherwise load by batch.
                if num_actual_eval_examples <= 1000:
                    sess.run(metric_vars_initializer)
                    sess.run(metric_op,
                             feed_dict={
                                 input_ids_ph: eval_input_ids,
                                 input_mask_ph: eval_input_mask,
                                 segment_ids_ph: eval_segment_ids,
                                 label_ids_ph: eval_label_ids,
                                 head_mask_ph: head_mask,
                                 layer_mask_ph: layer_mask
                             })
                    eval_metric_val = sess.run(metric_val)
                else:
                    num_batch_eval = num_actual_eval_examples // batch_size \
                        if num_actual_eval_examples % batch_size == 0 \
                        else num_actual_eval_examples // batch_size + 1
                    id_eval = 0
                    sess.run(metric_vars_initializer)
                    for _ in range(num_batch_eval):
                        eval_input_ids, eval_input_mask, eval_segment_ids, \
                        eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                                 seq_length=seq_length,
                                                                                 examples=eval_examples,
                                                                                 label_list=label_list,
                                                                                 tokenizer=tokenizer,
                                                                                 train_idx_offset=id_eval)
                        id_eval += batch_size
                        sess.run(metric_op,
                                 feed_dict={
                                     input_ids_ph: eval_input_ids,
                                     input_mask_ph: eval_input_mask,
                                     segment_ids_ph: eval_segment_ids,
                                     label_ids_ph: eval_label_ids,
                                     head_mask_ph: head_mask,
                                     layer_mask_ph: layer_mask
                                 })
                    eval_metric_val = sess.run(metric_val)

                for name, val in zip(metric_name, eval_metric_val):
                    if name == 'accuracy':
                        output[cur_l][cur_h] = val
                        print(
                            "Mask out the head in (Layer {}, Head {}) | {}: {}"
                            .format(cur_l, cur_h, name, val))

        joblib.dump(output, folder + save_file)
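
How the model consumes head_mask/layer_mask is defined in the repo's modified modeling.BertModel, which is not shown here. As a self-contained sketch of the idea only (the helper name and the NumPy formulation are illustrative assumptions, not the repo's API), the paired (layer, head) coordinates fed above can be expanded into a binary keep-mask that the attention layers would multiply per-head outputs by:

import numpy as np

def build_keep_mask(layer_mask, head_mask, n_layers=12, n_heads=12):
    # illustrative helper, not part of the repo: 1 = head active, 0 = masked out
    keep = np.ones((n_layers, n_heads), dtype=np.float32)
    for l, h in zip(layer_mask, head_mask):
        keep[l, h] = 0.0
    return keep

# one iteration of the loop above masks a single head, e.g. (layer 9, head 3):
print(build_keep_mask(layer_mask=[9], head_mask=[3])[9])
# -> [1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
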
Code example #2
def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and \
       not FLAGS.do_eval and \
       not FLAGS.do_pred:
        raise ValueError(
            "At least one of 'do_train', 'do_eval' or 'do_pred' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)
    if FLAGS.do_pred:
        test_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_test_examples = len(test_examples)
        print("num_actual_test_examples", num_actual_test_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None],
                                            name='label_ids')

    tf.compat.v1.logging.info("Running swapping layers!")

    num_train_steps = num_actual_train_examples // batch_size * epochs
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Swapped layers
    layer_folder_name = FLAGS.layers
    if FLAGS.layers is None:
        raise ValueError("In swapping experiments, layers must not be None.")
    layers = list(map(int, FLAGS.layers.split(',')))
    layer_order = list(range(12)) if embed_dim == 768 else list(range(6))
    if len(layers) == 2:
        # indexing layers[1] is only safe once we know two layers were given
        layer1, layer2 = layers[0], layers[1]
        if layer1 != layer2:
            layer_order[layer1], layer_order[layer2] = layer_order[
                layer2], layer_order[layer1]
        else:
            raise ValueError("Two layers should be different!")
    else:
        # more than two layers: reverse the whole block they span
        layer_order = layer_order[:layers[0]] + layers[::-1] + layer_order[
            layers[-1] + 1:]
    print("Current layer order: ", layer_order)

    # these placeholders control the dropout keep probability and the training flag
    keep_prob_ph = tf.compat.v1.placeholder(tf.float32, name="keep_prob")
    is_training_ph = tf.compat.v1.placeholder(tf.bool, name='is_training')

    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training_ph,
        input_ids=input_ids_ph,  # input_ids,
        input_mask=input_mask_ph,  # input_mask,
        token_type_ids=segment_ids_ph,  # segment_ids,
        use_one_hot_embeddings=False,
        layer_order=layer_order,
        use_estimator=False)

    output_layer = model.get_pooled_output()
    output_layer = tf.nn.dropout(output_layer, keep_prob=keep_prob_ph)
    output_weights = tf.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb (regression): cast the int32 label placeholder to float
        # so the squared-error loss builds with matching dtypes
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits -
                                         tf.cast(label_ids_ph, tf.float32))
            loss = tf.reduce_mean(per_example_loss)
            # the regression output doubles as the prediction, so the
            # sess.run(predictions) calls below also work for stsb
            predictions = logits
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric is tf.metric object, (val, op)
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    metric_phs = [
        tf.compat.v1.placeholder(tf.float32, name="{}_ph".format(key))
        for key in metric.keys()
    ]
    summaries = [
        tf.compat.v1.summary.scalar(key, metric_phs[i])
        for i, key in enumerate(metric.keys())
    ]
    train_summary_total = tf.summary.merge(summaries)
    eval_summary_total = tf.summary.merge(summaries)

    log_dir = FLAGS.output_dir + 'layer_{}/'.format(layer_folder_name)

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [
        v for v in tvars
        if 'output_weights' not in v.name and 'output_bias' not in v.name
    ]
    var_output = [
        v for v in tvars
        if 'output_weights' in v.name or "output_bias" in v.name
    ]

    if not FLAGS.load_from_finetuned:
        # Init from Model0
        saver_init = tf.train.Saver(var_init)
    else:
        # Init from Model1
        saver_init = tf.train.Saver(var_init + var_output)

    var_train = var_init + var_output
    print("Training parameters")
    for v in var_train:
        print(v)

    train_op = optimization.create_optimizer(loss=loss,
                                             init_lr=FLAGS.learning_rate,
                                             num_train_steps=num_train_steps,
                                             num_warmup_steps=num_warmup_steps,
                                             use_tpu=False,
                                             tvars=var_train)

    saver_all = tf.train.Saver(var_list=var_train, max_to_keep=1)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.variables_initializer(var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(log_dir + 'log/train/',
                                                 sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(log_dir + 'log/eval/')

        # if the number of eval examples is at most 1000, load them all at once; otherwise load by batch.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
            eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=num_actual_eval_examples,
                                                                     seq_length=seq_length,
                                                                     examples=eval_examples,
                                                                     label_list=label_list,
                                                                     tokenizer=tokenizer)

        start_metric = {"eval_{}".format(key): 0 for key in metric_name}
        if FLAGS.do_train:
            tf.logging.info("***** Run training *****")
            step = 1
            for n in range(epochs):

                np.random.shuffle(train_examples)
                num_batch = num_actual_train_examples // batch_size if num_actual_train_examples % batch_size == 0 \
                    else num_actual_train_examples // batch_size + 1
                id_train = 0

                for b in range(num_batch):

                    input_ids, input_mask, \
                    segment_ids, label_ids, is_real_example = generate_ph_input(batch_size=batch_size,
                                                                                seq_length=seq_length,
                                                                                examples=train_examples,
                                                                                label_list=label_list,
                                                                                tokenizer=tokenizer,
                                                                                train_idx_offset=id_train)
                    id_train += batch_size

                    sess.run(metric_vars_initializer)
                    sess.run([train_op] + metric_op,
                             feed_dict={
                                 input_ids_ph: input_ids,
                                 input_mask_ph: input_mask,
                                 segment_ids_ph: segment_ids,
                                 label_ids_ph: label_ids,
                                 is_training_ph: True,
                                 keep_prob_ph: 0.9
                             })
                    train_metric_val = sess.run(metric_val)
                    train_summary_str = sess.run(
                        train_summary_total,
                        feed_dict={
                            ph: value
                            for ph, value in zip(metric_phs, train_metric_val)
                        })
                    writer.add_summary(train_summary_str, step)

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        # evaluate on dev set

                        if num_actual_eval_examples <= 1000:

                            sess.run(metric_vars_initializer)
                            sess.run(metric_op,
                                     feed_dict={
                                         input_ids_ph: eval_input_ids,
                                         input_mask_ph: eval_input_mask,
                                         segment_ids_ph: eval_segment_ids,
                                         label_ids_ph: eval_label_ids,
                                         is_training_ph: False,
                                         keep_prob_ph: 1
                                     })
                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={
                                    ph: value
                                    for ph, value in zip(
                                        metric_phs, eval_metric_val)
                                })

                        else:
                            num_batch_eval = num_actual_eval_examples // batch_size \
                                if num_actual_eval_examples % batch_size == 0 \
                                else num_actual_eval_examples // batch_size + 1
                            id_eval = 0

                            sess.run(metric_vars_initializer)
                            for _ in range(num_batch_eval):
                                eval_input_ids, eval_input_mask, eval_segment_ids, \
                                eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                                         seq_length=seq_length,
                                                                                         examples=eval_examples,
                                                                                         label_list=label_list,
                                                                                         tokenizer=tokenizer,
                                                                                         train_idx_offset=id_eval)
                                id_eval += batch_size
                                sess.run(metric_op,
                                         feed_dict={
                                             input_ids_ph: eval_input_ids,
                                             input_mask_ph: eval_input_mask,
                                             segment_ids_ph: eval_segment_ids,
                                             label_ids_ph: eval_label_ids,
                                             is_training_ph: False,
                                             keep_prob_ph: 1
                                         })

                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={
                                    ph: value
                                    for ph, value in zip(
                                        metric_phs, eval_metric_val)
                                })

                        writer_eval.add_summary(eval_summary_str, step)

                        if step == 1:
                            for key, val in zip(metric_name, eval_metric_val):
                                start_metric["eval_{}".format(key)] = val

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        train_metric_list = []
                        for i in range(len(train_metric_val)):
                            if metric_name[i] == 'loss':
                                train_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                            else:
                                train_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                        train_str = 'Train ' + '|'.join(train_metric_list)

                        eval_metric_list = []
                        for i in range(len(eval_metric_val)):
                            if metric_name[i] == 'loss':
                                eval_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                            else:
                                eval_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                        eval_str = 'Eval ' + '|'.join(eval_metric_list)

                        print(
                            "Swap layer order {} | Epoch: %4d/%4d | Batch: %4d/%4d | {} | {}"
                            .format(layer_folder_name, train_str, eval_str) %
                            (n, epochs, b, num_batch))

                    if step % num_batch == 0:
                        saver_all.save(sess,
                                       log_dir +
                                       'swap_{}'.format(layer_folder_name),
                                       global_step=step)

                    step += 1

            writer.close()
            writer_eval.close()

        end_metric = {"eval_{}".format(key): 0 for key in metric_name}
        if FLAGS.do_eval:
            tf.logging.info("***** Run evaluation *****")
            if num_actual_eval_examples <= 1000:

                sess.run(metric_vars_initializer)
                sess.run(metric_op,
                         feed_dict={
                             input_ids_ph: eval_input_ids,
                             input_mask_ph: eval_input_mask,
                             segment_ids_ph: eval_segment_ids,
                             label_ids_ph: eval_label_ids,
                             is_training_ph: False,
                             keep_prob_ph: 1
                         })
                eval_metric_val = sess.run(metric_val)
                preds = sess.run(predictions,
                                 feed_dict={
                                     input_ids_ph: eval_input_ids,
                                     input_mask_ph: eval_input_mask,
                                     segment_ids_ph: eval_segment_ids,
                                     label_ids_ph: eval_label_ids,
                                     is_training_ph: False,
                                     keep_prob_ph: 1
                                 })
                eval_label_ids_lst = eval_label_ids
            else:
                num_batch_eval = num_actual_eval_examples // batch_size \
                    if num_actual_eval_examples % batch_size == 0 \
                    else num_actual_eval_examples // batch_size + 1
                id_eval = 0
                preds = np.zeros(num_actual_eval_examples)
                eval_label_ids_lst = np.zeros(num_actual_eval_examples)

                sess.run(metric_vars_initializer)
                for i in range(num_batch_eval):
                    eval_input_ids, eval_input_mask, eval_segment_ids, \
                    eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                             seq_length=seq_length,
                                                                             examples=eval_examples,
                                                                             label_list=label_list,
                                                                             tokenizer=tokenizer,
                                                                             train_idx_offset=id_eval)
                    id_eval += batch_size
                    sess.run(metric_op,
                             feed_dict={
                                 input_ids_ph: eval_input_ids,
                                 input_mask_ph: eval_input_mask,
                                 segment_ids_ph: eval_segment_ids,
                                 label_ids_ph: eval_label_ids,
                                 is_training_ph: False,
                                 keep_prob_ph: 1
                             })
                    pred = sess.run(predictions,
                                    feed_dict={
                                        input_ids_ph: eval_input_ids,
                                        input_mask_ph: eval_input_mask,
                                        segment_ids_ph: eval_segment_ids,
                                        label_ids_ph: eval_label_ids,
                                        is_training_ph: False,
                                        keep_prob_ph: 1
                                    })
                    preds[i * batch_size:min(id_eval, num_actual_eval_examples
                                             )] = pred[:]
                    eval_label_ids_lst[i * batch_size:min(
                        id_eval, num_actual_eval_examples)] = eval_label_ids[:]
                eval_metric_val = sess.run(metric_val)

            for key, val in zip(metric_name, eval_metric_val):
                end_metric["eval_{}".format(key)] = val

            output_predict_file = os.path.join(log_dir, 'dev_predictions.tsv')
            writer_output = tf.io.gfile.GFile(output_predict_file, "w")
            preds = preds.astype(int)
            eval_label_ids_lst = eval_label_ids_lst.astype(int)

            if task_name != 'stsb':
                writer_output.write(
                    "ID \t Label ID \t Label \t Ground Truth ID \t Ground Truth \n"
                )
            else:
                writer_output.write("ID \t Label \n")
            for (i, pred) in enumerate(preds):
                if task_name != 'stsb':
                    writer_output.write("{} \t {} \t {} \t {} \t {} \n".format(
                        i, pred, label_list[pred], eval_label_ids_lst[i],
                        label_list[eval_label_ids_lst[i]]))
                else:
                    writer_output.write("{} \t {} \n".format(i, pred))
            writer_output.close()
            tf.logging.info("***** Finish writing *****")

        print("Start metric", start_metric)
        print("End metric", end_metric)

        test_metric = {"test_{}".format(key): 0 for key in metric_name}
        if FLAGS.do_pred:
            # if the number of test examples is at most 1000, load them all at
            # once; otherwise load by batch.
            tf.logging.info("***** Predict results *****")
            if num_actual_test_examples <= 1000:
                test_input_ids, test_input_mask, test_segment_ids, \
                test_label_ids, test_is_real_example = generate_ph_input(batch_size=num_actual_test_examples,
                                                                         seq_length=seq_length,
                                                                         examples=test_examples,
                                                                         label_list=label_list,
                                                                         tokenizer=tokenizer)
                sess.run(metric_vars_initializer)
                sess.run(metric_op,
                         feed_dict={
                             input_ids_ph: test_input_ids,
                             input_mask_ph: test_input_mask,
                             segment_ids_ph: test_segment_ids,
                             label_ids_ph: test_label_ids,
                             is_training_ph: False,
                             keep_prob_ph: 1
                         })
                test_metric_val = sess.run(metric_val)
                preds = sess.run(predictions,
                                 feed_dict={
                                     input_ids_ph: test_input_ids,
                                     input_mask_ph: test_input_mask,
                                     segment_ids_ph: test_segment_ids,
                                     label_ids_ph: test_label_ids,
                                     is_training_ph: False,
                                     keep_prob_ph: 1
                                 })
                test_label_ids_lst = test_label_ids
            else:
                num_batch_test = num_actual_test_examples // batch_size \
                    if num_actual_test_examples % batch_size == 0 \
                    else num_actual_test_examples // batch_size + 1
                id_test = 0

                preds = np.zeros(num_actual_test_examples)
                test_label_ids_lst = np.zeros(num_actual_test_examples)
                sess.run(metric_vars_initializer)
                for i in range(num_batch_test):
                    test_input_ids, test_input_mask, test_segment_ids, \
                    test_label_ids, test_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                             seq_length=seq_length,
                                                                             examples=test_examples,
                                                                             label_list=label_list,
                                                                             tokenizer=tokenizer,
                                                                             train_idx_offset=id_test)
                    id_test += batch_size
                    sess.run(metric_op,
                             feed_dict={
                                 input_ids_ph: test_input_ids,
                                 input_mask_ph: test_input_mask,
                                 segment_ids_ph: test_segment_ids,
                                 label_ids_ph: test_label_ids,
                                 is_training_ph: False,
                                 keep_prob_ph: 1
                             })
                    pred = sess.run(predictions,
                                    feed_dict={
                                        input_ids_ph: test_input_ids,
                                        input_mask_ph: test_input_mask,
                                        segment_ids_ph: test_segment_ids,
                                        label_ids_ph: test_label_ids,
                                        is_training_ph: False,
                                        keep_prob_ph: 1
                                    })
                    preds[i * batch_size:min(id_test, num_actual_test_examples
                                             )] = pred[:]
                    test_label_ids_lst[i * batch_size:min(
                        id_test, num_actual_test_examples)] = test_label_ids[:]
                test_metric_val = sess.run(metric_val)
            for key, val in zip(metric_name, test_metric_val):
                test_metric["test_{}".format(key)] = val

            output_predict_file = os.path.join(log_dir, 'test_predictions.tsv')
            submit_predict_file = os.path.join(
                log_dir, "{}.tsv".format(standard_file_name[task_name]))
            writer_output = tf.io.gfile.GFile(output_predict_file, "w")
            writer_submit = tf.io.gfile.GFile(submit_predict_file, 'w')
            preds = preds.astype(int)
            test_label_ids_lst = test_label_ids_lst.astype(int)

            if task_name != 'stsb':
                writer_output.write(
                    "ID \t Label ID \t Label \t Ground Truth ID \t Ground Truth \n"
                )
            else:
                writer_output.write("ID \t Label \n")
            writer_submit.write("ID \t Label \n")
            for (i, pred) in enumerate(preds):
                if task_name != 'stsb':
                    writer_output.write("{} \t {} \t {} \t {} \t {} \n".format(
                        i, pred, label_list[pred], test_label_ids_lst[i],
                        label_list[test_label_ids_lst[i]]))
                    writer_submit.write("{} \t {} \n".format(
                        i, label_list[pred]))
                else:
                    writer_output.write("{} \t {} \n".format(i, pred))
                    writer_submit.write("{} \t {} \n".format(i, pred))
            writer_output.close()
            writer_submit.close()
            tf.logging.info("***** Finish writing *****")

        with tf.io.gfile.GFile(FLAGS.output_dir + 'results.txt',
                               'a') as writer:
            eval_start, eval_end, test_end = [], [], []
            for name in metric_name:
                if name != 'loss':
                    eval_start.append("{}: %.4f".format(name) %
                                      start_metric["eval_{}".format(name)])
                    eval_end.append("{}: %.4f".format(name) %
                                    end_metric["eval_{}".format(name)])
                    test_end.append("{}: %.4f".format(name) %
                                    test_metric["test_{}".format(name)])

            writer.write(
                "Swap layer order {}: Dev start: {} | Dev end: {} | Test end: {}\n"
                .format(layer_folder_name, ','.join(eval_start),
                        ','.join(eval_end), ','.join(test_end)))
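
The layer-order rule above is inlined in main; the following standalone sketch restates it (compute_layer_order is an illustrative name, not a function from the repo): two indices swap those layers, and more than two indices reverse the contiguous block they span.

def compute_layer_order(layers_flag, n_layers=12):
    # illustrative restatement of the ordering rule used above
    layers = list(map(int, layers_flag.split(',')))
    order = list(range(n_layers))
    if len(layers) == 2:
        i, j = layers
        order[i], order[j] = order[j], order[i]
    else:
        # assumes the indices are contiguous and ascending
        order = order[:layers[0]] + layers[::-1] + order[layers[-1] + 1:]
    return order

print(compute_layer_order('2,5'))      # [0, 1, 5, 3, 4, 2, 6, 7, 8, 9, 10, 11]
print(compute_layer_order('2,3,4,5'))  # [0, 1, 5, 4, 3, 2, 6, 7, 8, 9, 10, 11]
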
Code example #3
def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_eval:
        raise ValueError("At least 'do_eval' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.train_batch_size
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length], name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length], name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length], name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None], name='label_ids')

    tf.compat.v1.logging.info("Running cumulatively mask out heads and direct evaluation!")

    # we cumulatively mask out heads in importance order and evaluate after each one, covering all 12 layers * 12 heads.
    n_layers = 12
    n_heads = 12
    folder = FLAGS.output_dir
    save_file = 'cumulative_heads_mask.pickle'
    output = np.zeros((n_layers, n_heads))

    importance_coordinates = []
    if FLAGS.importance_setting == 'l2_norm':
        #######################
        # 12 * 12, layer * head, importance increases from 1 to 144.
        # [[127, 63, 67, 72, 91, 93, 124, 100, 96, 134, 15, 133],
        #  [143, 60, 107, 128, 57, 106, 118, 83, 144, 135, 99, 116],
        #  [30, 111, 115, 132, 112, 94, 122, 66, 123, 40, 75, 89],
        #  [108, 140, 141, 119, 121, 62, 137, 131, 125, 70, 85, 105],
        #  [126, 120, 113, 136, 92, 79, 110, 74, 103, 84, 86, 53],
        #  [69, 87, 117, 80, 142, 114, 129, 104, 97, 18, 52, 77],
        #  [81, 88, 48, 90, 56, 50, 58, 101, 130, 64, 35, 46],
        #  [20, 41, 38, 32, 71, 59, 82, 43, 78, 55, 47, 37],
        #  [24, 95, 27, 44, 65, 12, 28, 102, 98, 23, 14, 19],
        #  [17, 16, 2, 1, 9, 13, 68, 4, 139, 7, 21, 109],
        #  [76, 26, 8, 138, 10, 29, 31, 54, 6, 36, 3, 49],
        #  [61, 51, 42, 45, 34, 5, 73, 33, 25, 22, 39, 11]]
        #######################
        # sorted head coordinate array by importance (L2 weight magnitude), (layer, head)
        # least 1 --> most 144
        importance_coordinates = [[9, 3], [9, 2], [10, 10], [9, 7], [11, 5], [10, 8], [9, 9], [10, 2], [9, 4],
                                  [10, 4], [11, 11], [8, 5], [9, 5], [8, 10], [0, 10], [9, 1], [9, 0], [5, 9],
                                  [8, 11], [7, 0], [9, 10], [11, 9], [8, 9], [8, 0], [11, 8], [10, 1], [8, 2],
                                  [8, 6], [10, 5], [2, 0], [10, 6], [7, 3], [11, 7], [11, 4], [6, 10], [10, 9],
                                  [7, 11], [7, 2], [11, 10], [2, 9], [7, 1], [11, 2], [7, 7], [8, 3], [11, 3],
                                  [6, 11], [7, 10], [6, 2], [10, 11], [6, 5], [11, 1], [5, 10], [4, 11], [10, 7],
                                  [7, 9], [6, 4], [1, 4], [6, 6], [7, 5], [1, 1], [11, 0], [3, 5], [0, 1], [6, 9],
                                  [8, 4], [2, 7], [0, 2], [9, 6], [5, 0], [3, 9], [7, 4], [0, 3], [11, 6], [4, 7],
                                  [2, 10], [10, 0], [5, 11], [7, 8], [4, 5], [5, 3], [6, 0], [7, 6], [1, 7], [4, 9],
                                  [3, 10], [4, 10], [5, 1], [6, 1], [2, 11], [6, 3], [0, 4], [4, 4], [0, 5], [2, 5],
                                  [8, 1], [0, 8], [5, 8], [8, 8], [1, 10], [0, 7], [6, 7], [8, 7], [4, 8], [5, 7],
                                  [3, 11], [1, 5], [1, 2], [3, 0], [9, 11], [4, 6], [2, 1], [2, 4], [4, 2], [5, 5],
                                  [2, 2], [1, 11], [5, 2], [1, 6], [3, 3], [4, 1], [3, 4], [2, 6], [2, 8], [0, 6],
                                  [3, 8], [4, 0], [0, 0], [1, 3], [5, 6], [6, 8], [3, 7], [2, 3], [0, 11], [0, 9],
                                  [1, 9], [4, 3], [3, 6], [10, 3], [9, 8], [3, 1], [3, 2], [5, 4], [1, 0], [1, 8]]
    elif FLAGS.importance_setting == 'per_head_score':
        #######################
        # 12 * 12, layer * head, importance increases from 1 to 144.
        # [[127, 63, 67, 72, 91, 93, 124, 100, 96, 134, 15, 133],
        #  [143, 60, 107, 128, 57, 106, 118, 83, 144, 135, 99, 116],
        #  [30, 111, 115, 132, 112, 94, 122, 66, 123, 40, 75, 89],
        #  [108, 140, 141, 119, 121, 62, 137, 131, 125, 70, 85, 105],
        #  [126, 120, 113, 136, 92, 79, 110, 74, 103, 84, 86, 53],
        #  [69, 87, 117, 80, 142, 114, 129, 104, 97, 18, 52, 77],
        #  [81, 88, 48, 90, 56, 50, 58, 101, 130, 64, 35, 46],
        #  [20, 41, 38, 32, 71, 59, 82, 43, 78, 55, 47, 37],
        #  [24, 95, 27, 44, 65, 12, 28, 102, 98, 23, 14, 19],
        #  [17, 16, 2, 1, 9, 13, 68, 4, 139, 7, 21, 109],
        #  [76, 26, 8, 138, 10, 29, 31, 54, 6, 36, 3, 49],
        #  [61, 51, 42, 45, 34, 5, 73, 33, 25, 22, 39, 11]]
        #######################
        # sorted head coordinate array by importance (per-head score), (layer, head)
        # least 1 --> most 144
        importance_coordinates = [[9, 3], [9, 2], [10, 10], [9, 7], [11, 5], [10, 8], [9, 9], [10, 2], [9, 4],
                                  [10, 4], [11, 11], [8, 5], [9, 5], [8, 10], [0, 10], [9, 1], [9, 0], [5, 9],
                                  [8, 11], [7, 0], [9, 10], [11, 9], [8, 9], [8, 0], [11, 8], [10, 1], [8, 2],
                                  [8, 6], [10, 5], [2, 0], [10, 6], [7, 3], [11, 7], [11, 4], [6, 10], [10, 9],
                                  [7, 11], [7, 2], [11, 10], [2, 9], [7, 1], [11, 2], [7, 7], [8, 3], [11, 3],
                                  [6, 11], [7, 10], [6, 2], [10, 11], [6, 5], [11, 1], [5, 10], [4, 11], [10, 7],
                                  [7, 9], [6, 4], [1, 4], [6, 6], [7, 5], [1, 1], [11, 0], [3, 5], [0, 1], [6, 9],
                                  [8, 4], [2, 7], [0, 2], [9, 6], [5, 0], [3, 9], [7, 4], [0, 3], [11, 6], [4, 7],
                                  [2, 10], [10, 0], [5, 11], [7, 8], [4, 5], [5, 3], [6, 0], [7, 6], [1, 7], [4, 9],
                                  [3, 10], [4, 10], [5, 1], [6, 1], [2, 11], [6, 3], [0, 4], [4, 4], [0, 5], [2, 5],
                                  [8, 1], [0, 8], [5, 8], [8, 8], [1, 10], [0, 7], [6, 7], [8, 7], [4, 8], [5, 7],
                                  [3, 11], [1, 5], [1, 2], [3, 0], [9, 11], [4, 6], [2, 1], [2, 4], [4, 2], [5, 5],
                                  [2, 2], [1, 11], [5, 2], [1, 6], [3, 3], [4, 1], [3, 4], [2, 6], [2, 8], [0, 6],
                                  [3, 8], [4, 0], [0, 0], [1, 3], [5, 6], [6, 8], [3, 7], [2, 3], [0, 11], [0, 9],
                                  [1, 9], [4, 3], [3, 6], [10, 3], [9, 8], [3, 1], [3, 2], [5, 4], [1, 0], [1, 8]]
    # mask out heads starting from the most important one
    if FLAGS.from_most:
        importance_coordinates.reverse()

    # two placeholders for the masked-head coordinates: layer index and head index
    head_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None], name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None], name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,  # input_ids,
        input_mask=input_mask_ph,  # input_mask,
        token_type_ids=segment_ids_ph,  # segment_ids,
        use_one_hot_embeddings=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    output_layer = model.get_pooled_output()
    output_weights = tf.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb (regression): cast the int32 label placeholder to float
        # so the squared-error loss builds with matching dtypes
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - tf.cast(label_ids_ph, tf.float32))
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph, depth=num_labels, dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
            loss = tf.reduce_mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric is tf.metric object, (val, op)
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels, task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.train.Saver(tvars)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric.keys():
        var_metric.extend(tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.variables_initializer(var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        # if the number of eval examples is at most 1000, load them all at once; otherwise load by batch.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
            eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=num_actual_eval_examples,
                                                                     seq_length=seq_length,
                                                                     examples=eval_examples,
                                                                     label_list=label_list,
                                                                     tokenizer=tokenizer)

        # loop over all 144 heads in importance order (l and h are only loop counters here)
        idx = 0
        for l in range(n_layers):
            for h in range(n_heads):

                cur_l, cur_h = importance_coordinates[idx]
                coor_mask = importance_coordinates[0:idx + 1]
                head_mask = [head for _, head in coor_mask]
                layer_mask = [layer for layer, _ in coor_mask]

                # if the number of eval examples is at most 1000, load them all at once; otherwise load by batch.
                if num_actual_eval_examples <= 1000:
                    sess.run(metric_vars_initializer)
                    sess.run(metric_op, feed_dict={input_ids_ph: eval_input_ids,
                                                   input_mask_ph: eval_input_mask,
                                                   segment_ids_ph: eval_segment_ids,
                                                   label_ids_ph: eval_label_ids,
                                                   head_mask_ph: head_mask,
                                                   layer_mask_ph: layer_mask})
                    eval_metric_val = sess.run(metric_val)
                else:
                    num_batch_eval = num_actual_eval_examples // batch_size \
                        if num_actual_eval_examples % batch_size == 0 \
                        else num_actual_eval_examples // batch_size + 1
                    id_eval = 0
                    sess.run(metric_vars_initializer)
                    for _ in range(num_batch_eval):
                        eval_input_ids, eval_input_mask, eval_segment_ids, \
                        eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                                 seq_length=seq_length,
                                                                                 examples=eval_examples,
                                                                                 label_list=label_list,
                                                                                 tokenizer=tokenizer,
                                                                                 train_idx_offset=id_eval)
                        id_eval += batch_size
                        sess.run(metric_op, feed_dict={input_ids_ph: eval_input_ids,
                                                       input_mask_ph: eval_input_mask,
                                                       segment_ids_ph: eval_segment_ids,
                                                       label_ids_ph: eval_label_ids,
                                                       head_mask_ph: head_mask,
                                                       layer_mask_ph: layer_mask})
                    eval_metric_val = sess.run(metric_val)

                for name, val in zip(metric_name, eval_metric_val):
                    if name == 'accuracy':
                        output[cur_l][cur_h] = val
                        print("Mask out the {}th head in (Layer {}, Head {}) | {}: {}"
                              .format(idx + 1, cur_l, cur_h, name, val))

                idx += 1

        joblib.dump(output, folder + save_file)
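
Unlike code example #1, which masks a single head per evaluation, each iteration here feeds the entire prefix of the importance ordering, so the idx-th head is masked together with every head ranked before it. A small sketch of just that slicing, using the first three coordinates from the ordering above:

importance_coordinates = [[9, 3], [9, 2], [10, 10]]  # first three entries

for idx in range(len(importance_coordinates)):
    coor_mask = importance_coordinates[:idx + 1]
    layer_mask = [layer for layer, _ in coor_mask]
    head_mask = [head for _, head in coor_mask]
    print(idx + 1, layer_mask, head_mask)
# 1 [9] [3]
# 2 [9, 9] [3, 2]
# 3 [9, 9, 10] [3, 2, 10]
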
Code example #4
def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train:
        raise ValueError("At least 'do_train' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[
                                                None,
                                            ],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Keeping only the most important head per layer, then fine-tuning!")

    num_train_steps = num_actual_train_examples // batch_size * epochs
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
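    # e.g. 10,000 train steps with warmup_proportion=0.1 yield 1,000 warmup
    # steps; BERT's optimization.create_optimizer ramps the learning rate
    # linearly over those steps and then decays it linearly afterwards.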

    cur_layer = FLAGS.cur_layer
    most_important_head = FLAGS.most_important_head
    print("Current most important head:", cur_layer, most_important_head)

    # this placeholder is to control the flag for the dropout
    keep_prob_ph = tf.compat.v1.placeholder(tf.float32, name="keep_prob")
    is_training_ph = tf.compat.v1.placeholder(tf.bool, name='is_training')

    # two placeholders for the head coordinates, layer, head
    head_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[
                                                None,
                                            ],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[
                                                 None,
                                             ],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training_ph,
        input_ids=input_ids_ph,  # input_ids,
        input_mask=input_mask_ph,  # input_mask,
        token_type_ids=segment_ids_ph,  # segment_ids,
        use_one_hot_embeddings=False,
        use_estimator=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)
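    # head_mask/layer_mask are extensions in this repo's modeling.py (stock
    # BERT has no such arguments); each pair (layer_mask[i], head_mask[i])
    # presumably names one attention head to zero out in the forward pass.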

    output_layer = model.get_pooled_output()
    # dropout controlled via keep_prob_ph (fed as 1.0 at eval time)
    output_layer = tf.nn.dropout(output_layer, rate=1 - keep_prob_ph)
    output_weights = tf.compat.v1.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.compat.v1.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            # cast the int32 label placeholder to float so the subtraction
            # type-checks (a proper STS-B run would need float labels)
            per_example_loss = tf.square(logits -
                                         tf.cast(label_ids_ph, tf.float32))
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric is tf.metric object, (val, op)
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    log_dir = FLAGS.output_dir + 'layer_{}_head_{}/'.format(
        cur_layer, most_important_head)

    metric_phs = [
        tf.compat.v1.placeholder(tf.float32, name="{}_ph".format(key))
        for key in metric.keys()
    ]
    summaries = [
        tf.compat.v1.summary.scalar(key, metric_phs[i])
        for i, key in enumerate(metric.keys())
    ]
    train_summary_total = tf.compat.v1.summary.merge(summaries)
    eval_summary_total = tf.compat.v1.summary.merge(summaries)

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [
        v for v in tvars
        if 'output_weights' not in v.name and 'output_bias' not in v.name
    ]
    var_output = [
        v for v in tvars
        if 'output_weights' in v.name or "output_bias" in v.name
    ]

    if not FLAGS.load_from_finetuned:
        # Init from Model0
        saver_init = tf.compat.v1.train.Saver(var_init)
    else:
        # Init from Model1
        saver_init = tf.compat.v1.train.Saver(var_init + var_output)

    var_train = var_init + var_output
    print("Training parameters")
    for v in var_train:
        print(v)

    train_op = optimization.create_optimizer(loss=loss,
                                             init_lr=FLAGS.learning_rate,
                                             num_train_steps=num_train_steps,
                                             num_warmup_steps=num_warmup_steps,
                                             use_tpu=False,
                                             tvars=var_train)

    saver_all = tf.compat.v1.train.Saver(var_list=var_init + var_output,
                                         max_to_keep=1)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.compat.v1.variables_initializer(
        var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(log_dir + 'log/train/',
                                                 sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(log_dir + 'log/eval/')

        # in cur_layer, mask out every head except the most important one
        head_mask = [head for head in range(12) if head != most_important_head]
        layer_mask = [cur_layer for _ in range(11)]
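        # e.g. cur_layer=3, most_important_head=7 gives
        #   head_mask  = [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11]  (11 entries)
        #   layer_mask = [3] * 11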

        # if there are at most 1000 eval examples, load them all at once;
        # otherwise load them batch by batch.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
            eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=num_actual_eval_examples,
                                                                     seq_length=seq_length,
                                                                     examples=eval_examples,
                                                                     label_list=label_list,
                                                                     tokenizer=tokenizer)

        start_metric = {"eval_{}".format(key): 0 for key in metric_name}
        end_metric = {"eval_{}".format(key): 0 for key in metric_name}

        if FLAGS.do_train:
            tf.compat.v1.logging.info("***** Running training *****")
            step = 1
            for n in range(epochs):

                np.random.shuffle(train_examples)
                num_batch = num_actual_train_examples // batch_size if num_actual_train_examples % batch_size == 0 \
                    else num_actual_train_examples // batch_size + 1
                id_train = 0  # offset into the shuffled training examples

                for b in range(num_batch):

                    input_ids, input_mask, \
                    segment_ids, label_ids, is_real_example = generate_ph_input(
                        batch_size=batch_size,
                        seq_length=seq_length,
                        examples=train_examples,
                        label_list=label_list,
                        tokenizer=tokenizer,
                        train_idx_offset=id_train)
                    id_train += batch_size
                    sess.run(metric_vars_initializer)
                    sess.run([train_op] + metric_op,
                             feed_dict={
                                 input_ids_ph: input_ids,
                                 input_mask_ph: input_mask,
                                 segment_ids_ph: segment_ids,
                                 label_ids_ph: label_ids,
                                 is_training_ph: True,
                                 keep_prob_ph: 0.9,
                                 head_mask_ph: head_mask,
                                 layer_mask_ph: layer_mask
                             })
                    train_metric_val = sess.run(metric_val)
                    train_summary_str = sess.run(
                        train_summary_total,
                        feed_dict={
                            ph: value
                            for ph, value in zip(metric_phs, train_metric_val)
                        })
                    writer.add_summary(train_summary_str, step)

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        # evaluate on dev set

                        if num_actual_eval_examples <= 1000:
                            sess.run(metric_vars_initializer)
                            sess.run(metric_op,
                                     feed_dict={
                                         input_ids_ph: eval_input_ids,
                                         input_mask_ph: eval_input_mask,
                                         segment_ids_ph: eval_segment_ids,
                                         label_ids_ph: eval_label_ids,
                                         is_training_ph: False,
                                         keep_prob_ph: 1,
                                         head_mask_ph: head_mask,
                                         layer_mask_ph: layer_mask
                                     })
                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={
                                    ph: value
                                    for ph, value in zip(
                                        metric_phs, eval_metric_val)
                                })
                        else:
                            num_batch_eval = num_actual_eval_examples // batch_size \
                                if num_actual_eval_examples % batch_size == 0 \
                                else num_actual_eval_examples // batch_size + 1
                            id_eval = 0

                            sess.run(metric_vars_initializer)
                            for _ in range(num_batch_eval):
                                eval_input_ids, eval_input_mask, eval_segment_ids, \
                                eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                                         seq_length=seq_length,
                                                                                         examples=eval_examples,
                                                                                         label_list=label_list,
                                                                                         tokenizer=tokenizer,
                                                                                         train_idx_offset=id_eval)
                                id_eval += batch_size
                                sess.run(metric_op,
                                         feed_dict={
                                             input_ids_ph: eval_input_ids,
                                             input_mask_ph: eval_input_mask,
                                             segment_ids_ph: eval_segment_ids,
                                             label_ids_ph: eval_label_ids,
                                             is_training_ph: False,
                                             keep_prob_ph: 1,
                                             head_mask_ph: head_mask,
                                             layer_mask_ph: layer_mask
                                         })

                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={
                                    ph: value
                                    for ph, value in zip(
                                        metric_phs, eval_metric_val)
                                })

                        writer_eval.add_summary(eval_summary_str, step)

                        if step == 1:
                            for key, val in zip(metric_name, eval_metric_val):
                                start_metric["eval_{}".format(key)] = val
                        if step == epochs * num_batch:
                            for key, val in zip(metric_name, eval_metric_val):
                                end_metric["eval_{}".format(key)] = val

                        # at the same checkpoints, print a progress line
                        train_metric_list = []
                        for i in range(len(train_metric_val)):
                            if metric_name[i] == 'loss':
                                train_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                            else:
                                train_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                        train_str = 'Train ' + '|'.join(train_metric_list)

                        eval_metric_list = []
                        for i in range(len(eval_metric_val)):
                            if metric_name[i] == 'loss':
                                eval_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                            else:
                                eval_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                        eval_str = 'Eval ' + '|'.join(eval_metric_list)

                        print(
                            "Layer {}, keeping only head {} | Epoch: %4d/%4d | Batch: %4d/%4d | {} | {}"
                            .format(cur_layer, most_important_head, train_str,
                                    eval_str) % (n + 1, epochs, b + 1, num_batch))

                    if step % num_batch == 0:
                        saver_all.save(sess,
                                       log_dir + 'layer_{}_head_{}'.format(
                                           cur_layer, most_important_head),
                                       global_step=step)

                    step += 1

            writer.close()
            writer_eval.close()

        print("Start metric", start_metric)
        print("End metric", end_metric)

        with tf.io.gfile.GFile(FLAGS.output_dir + 'results.txt',
                               'a') as results_file:
            eval_start, eval_end = [], []
            for name in metric_name:
                if name != 'loss':
                    eval_start.append("{}: %.4f".format(name) %
                                      start_metric["eval_{}".format(name)])
                    eval_end.append("{}: %.4f".format(name) %
                                    end_metric["eval_{}".format(name)])

            results_file.write(
                "Layer {}, keep only head {}: Start: {} | End: {}\n".format(
                    cur_layer, most_important_head, ','.join(eval_start),
                    ','.join(eval_end)))
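
For reference, a minimal sketch (an assumption about what the modified modeling.py does with these feeds, not code from this repo) of how the two coordinate lists can be turned into a binary [n_layers, n_heads] keep-mask that an attention layer could multiply against its per-head outputs:

import numpy as np

def build_keep_mask(layer_mask, head_mask, n_layers=12, n_heads=12):
    # 1.0 = keep the head, 0.0 = zero out its output
    keep = np.ones((n_layers, n_heads), dtype=np.float32)
    for layer, head in zip(layer_mask, head_mask):
        keep[layer, head] = 0.0
    return keep

# Example matching the code above: mask all heads of layer 3 except head 7.
head_mask = [h for h in range(12) if h != 7]
layer_mask = [3] * 11
mask = build_keep_mask(layer_mask, head_mask)
print(mask[3])  # zeros everywhere except index 7, which stays 1.0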