Beispiel #1
0
def main(_):
    """Evaluate a checkpoint once per (layer, head) with that single head masked.

    The dev set of ``FLAGS.task_name`` is scored repeatedly; on each pass one
    (layer, head) coordinate is fed to the model's ``layer_mask`` /
    ``head_mask`` inputs.  The ``'accuracy'`` metric of every pass is stored in
    a ``(n_layers, n_heads)`` array which is finally written with ``joblib`` to
    ``single_head_mask.pickle`` inside ``FLAGS.output_dir``.

    Args:
        _: Unused positional argument (presumably the argv list handed over by
            the app runner -- the caller is not visible here).

    Raises:
        ValueError: If ``FLAGS.do_eval`` is false, if ``FLAGS.max_seq_length``
            exceeds the model's ``max_position_embeddings``, or if the task
            name has no registered processor.
    """
    import os  # function-scope import: only used to build the result path

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    # Maps the lower-cased task name to the class that loads its data.
    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_eval:
        raise ValueError("At least 'do_eval' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest: it has its own processor, but is
    # passed on (to metric_fn below) under the plain 'mnlim' task name.
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.eval_batch_size
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[
                                                None,
                                            ],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Running single-head masking out and direct evaluation!")

    # We mask out one individual head and then evaluate, giving one result per
    # (layer, head).  The grid size is read from the model config so smaller
    # BERT variants are not probed at coordinates they do not have; 12 x 12
    # (the previously hard-coded BERT-base grid) is the fallback.
    # NOTE(review): assumes BertConfig exposes num_hidden_layers /
    # num_attention_heads as stock BERT does -- confirm for this fork.
    n_layers = getattr(bert_config, "num_hidden_layers", 12)
    n_heads = getattr(bert_config, "num_attention_heads", 12)
    folder = FLAGS.output_dir
    save_file = 'single_head_mask.pickle'
    output = np.zeros((n_layers, n_heads))

    # two placeholders for the head coordinates, layer, head
    head_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[
                                                None,
                                            ],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[
                                                 None,
                                             ],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,  # input_ids,
        input_mask=input_mask_ph,  # input_mask,
        token_type_ids=segment_ids_ph,  # segment_ids,
        use_one_hot_embeddings=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    # Classification head on top of the pooled output.
    output_layer = model.get_pooled_output()
    output_weights = tf.compat.v1.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.compat.v1.get_variable(
        "output_bias", [num_labels],
        initializer=tf.compat.v1.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb
        if num_labels == 1:
            # NOTE(review): label_ids_ph is declared int32 above while logits
            # are float, so this subtraction looks like it will be rejected
            # by TensorFlow -- confirm how the regression task is meant to
            # be fed before relying on this branch.
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric maps a metric name to a (value tensor, update op) pair
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.compat.v1.train.Saver(tvars)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric:
        var_metric.extend(
            tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.compat.v1.variables_initializer(
        var_list=var_metric)

    # With at most 1000 eval examples everything is fed in one go; otherwise
    # the examples are fed batch by batch.
    load_all_at_once = num_actual_eval_examples <= 1000

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        if load_all_at_once:
            # The inputs do not depend on the mask, so build them only once.
            (eval_input_ids, eval_input_mask, eval_segment_ids,
             eval_label_ids, _is_real_example) = generate_ph_input(
                 batch_size=num_actual_eval_examples,
                 seq_length=seq_length,
                 examples=eval_examples,
                 label_list=label_list,
                 tokenizer=tokenizer)

        # loop over layers, then loop over heads
        for layer_idx in range(n_layers):
            for head_idx in range(n_heads):
                head_mask = [head_idx]
                layer_mask = [layer_idx]

                # Reset the running metric state before every masked pass.
                sess.run(metric_vars_initializer)

                if load_all_at_once:
                    sess.run(metric_op,
                             feed_dict={
                                 input_ids_ph: eval_input_ids,
                                 input_mask_ph: eval_input_mask,
                                 segment_ids_ph: eval_segment_ids,
                                 label_ids_ph: eval_label_ids,
                                 head_mask_ph: head_mask,
                                 layer_mask_ph: layer_mask
                             })
                else:
                    # Offsets 0, batch_size, 2 * batch_size, ...: one step per
                    # batch, including a final partial one.
                    for id_eval in range(0, num_actual_eval_examples,
                                         batch_size):
                        (eval_input_ids, eval_input_mask, eval_segment_ids,
                         eval_label_ids, _is_real_example) = generate_ph_input(
                             batch_size=batch_size,
                             seq_length=seq_length,
                             examples=eval_examples,
                             label_list=label_list,
                             tokenizer=tokenizer,
                             train_idx_offset=id_eval)
                        sess.run(metric_op,
                                 feed_dict={
                                     input_ids_ph: eval_input_ids,
                                     input_mask_ph: eval_input_mask,
                                     segment_ids_ph: eval_segment_ids,
                                     label_ids_ph: eval_label_ids,
                                     head_mask_ph: head_mask,
                                     layer_mask_ph: layer_mask
                                 })

                eval_metric_val = sess.run(metric_val)

                # Only the metric named 'accuracy' is recorded; cells stay 0
                # when metric_fn does not report one for the task.
                for name, val in zip(metric_name, eval_metric_val):
                    if name == 'accuracy':
                        output[layer_idx][head_idx] = val
                        print(
                            "Mask out the head in (Layer {}, Head {}) | {}: {}"
                            .format(layer_idx, head_idx, name, val))

        # os.path.join keeps the file inside output_dir whether or not the
        # flag value ends with a path separator.
        joblib.dump(output, os.path.join(folder, save_file))
Beispiel #2
0
def main(_):
    """Evaluate a checkpoint while cumulatively masking attention heads.

    Heads are taken from a fixed importance ranking of (layer, head)
    coordinates (optionally reversed via ``FLAGS.from_most``).  After step
    ``k`` the first ``k`` ranked heads are fed to the model's ``layer_mask`` /
    ``head_mask`` inputs and the dev set is scored.  The ``'accuracy'`` of
    step ``k`` is stored at the coordinate of the ``k``-th head in a 12 x 12
    array, which is finally written with ``joblib`` to
    ``cumulative_heads_mask.pickle`` inside ``FLAGS.output_dir``.

    Args:
        _: Unused positional argument (presumably the argv list handed over by
            the app runner -- the caller is not visible here).

    Raises:
        ValueError: If ``FLAGS.do_eval`` is false, if ``FLAGS.max_seq_length``
            exceeds the model's ``max_position_embeddings``, if the task name
            has no registered processor, or if ``FLAGS.importance_setting``
            is not one of ``'l2_norm'`` / ``'per_head_score'``.
    """
    import os  # function-scope import: only used to build the result path

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    # Maps the lower-cased task name to the class that loads its data.
    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_eval:
        raise ValueError("At least 'do_eval' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    # The training examples are only loaded to report their count below.
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    # NOTE(review): evaluation batches are sized by the *train* batch flag
    # here -- confirm this is intended.
    batch_size = FLAGS.train_batch_size
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length], name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length], name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length], name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, ], name='label_ids')

    tf.compat.v1.logging.info("Running cumulatively mask out heads and direct evaluation!")

    # One result cell per head.  The grid stays 12 x 12 because the ranking
    # table below is written for 12 layers * 12 heads.
    n_layers = 12
    n_heads = 12
    folder = FLAGS.output_dir
    save_file = 'cumulative_heads_mask.pickle'
    output = np.zeros((n_layers, n_heads))

    #######################
    # 12 * 12, layer * head, importance increases from 1 to 144.
    # [[127, 63, 67, 72, 91, 93, 124, 100, 96, 134, 15, 133],
    #  [143, 60, 107, 128, 57, 106, 118, 83, 144, 135, 99, 116],
    #  [30, 111, 115, 132, 112, 94, 122, 66, 123, 40, 75, 89],
    #  [108, 140, 141, 119, 121, 62, 137, 131, 125, 70, 85, 105],
    #  [126, 120, 113, 136, 92, 79, 110, 74, 103, 84, 86, 53],
    #  [69, 87, 117, 80, 142, 114, 129, 104, 97, 18, 52, 77],
    #  [81, 88, 48, 90, 56, 50, 58, 101, 130, 64, 35, 46],
    #  [20, 41, 38, 32, 71, 59, 82, 43, 78, 55, 47, 37],
    #  [24, 95, 27, 44, 65, 12, 28, 102, 98, 23, 14, 19],
    #  [17, 16, 2, 1, 9, 13, 68, 4, 139, 7, 21, 109],
    #  [76, 26, 8, 138, 10, 29, 31, 54, 6, 36, 3, 49],
    #  [61, 51, 42, 45, 34, 5, 73, 33, 25, 22, 39, 11]]
    #######################
    # sorted head coordinate array by importance (L2 weight magnitude), (layer, head)
    # least 1 --> most 144
    l2_norm_ranking = [[9, 3], [9, 2], [10, 10], [9, 7], [11, 5], [10, 8], [9, 9], [10, 2], [9, 4],
                       [10, 4], [11, 11], [8, 5], [9, 5], [8, 10], [0, 10], [9, 1], [9, 0], [5, 9],
                       [8, 11], [7, 0], [9, 10], [11, 9], [8, 9], [8, 0], [11, 8], [10, 1], [8, 2],
                       [8, 6], [10, 5], [2, 0], [10, 6], [7, 3], [11, 7], [11, 4], [6, 10], [10, 9],
                       [7, 11], [7, 2], [11, 10], [2, 9], [7, 1], [11, 2], [7, 7], [8, 3], [11, 3],
                       [6, 11], [7, 10], [6, 2], [10, 11], [6, 5], [11, 1], [5, 10], [4, 11], [10, 7],
                       [7, 9], [6, 4], [1, 4], [6, 6], [7, 5], [1, 1], [11, 0], [3, 5], [0, 1], [6, 9],
                       [8, 4], [2, 7], [0, 2], [9, 6], [5, 0], [3, 9], [7, 4], [0, 3], [11, 6], [4, 7],
                       [2, 10], [10, 0], [5, 11], [7, 8], [4, 5], [5, 3], [6, 0], [7, 6], [1, 7], [4, 9],
                       [3, 10], [4, 10], [5, 1], [6, 1], [2, 11], [6, 3], [0, 4], [4, 4], [0, 5], [2, 5],
                       [8, 1], [0, 8], [5, 8], [8, 8], [1, 10], [0, 7], [6, 7], [8, 7], [4, 8], [5, 7],
                       [3, 11], [1, 5], [1, 2], [3, 0], [9, 11], [4, 6], [2, 1], [2, 4], [4, 2], [5, 5],
                       [2, 2], [1, 11], [5, 2], [1, 6], [3, 3], [4, 1], [3, 4], [2, 6], [2, 8], [0, 6],
                       [3, 8], [4, 0], [0, 0], [1, 3], [5, 6], [6, 8], [3, 7], [2, 3], [0, 11], [0, 9],
                       [1, 9], [4, 3], [3, 6], [10, 3], [9, 8], [3, 1], [3, 2], [5, 4], [1, 0], [1, 8]]

    if FLAGS.importance_setting == 'l2_norm':
        importance_coordinates = list(l2_norm_ranking)
    elif FLAGS.importance_setting == 'per_head_score':
        # NOTE(review): this setting has always used a verbatim copy of the
        # L2-norm ranking (same table, same comment); that behaviour is kept,
        # but it looks like a placeholder -- confirm the intended ranking.
        importance_coordinates = list(l2_norm_ranking)
    else:
        # Fail before the graph is built and the checkpoint is restored,
        # instead of hitting an IndexError on an empty ranking much later.
        raise ValueError(
            "Unknown importance_setting: %s (expected 'l2_norm' or "
            "'per_head_score')" % (FLAGS.importance_setting,))

    # mask out heads from the most important one
    if FLAGS.from_most:
        importance_coordinates.reverse()

    # two placeholders for the head coordinates, layer, head
    head_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, ], name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, ], name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,  # input_ids,
        input_mask=input_mask_ph,  # input_mask,
        token_type_ids=segment_ids_ph,  # segment_ids,
        use_one_hot_embeddings=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    # Classification head on top of the pooled output.
    output_layer = model.get_pooled_output()
    output_weights = tf.compat.v1.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.compat.v1.get_variable(
        "output_bias", [num_labels], initializer=tf.compat.v1.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb
        if num_labels == 1:
            # NOTE(review): label_ids_ph is declared int32 above while logits
            # are float, so this subtraction looks like it will be rejected
            # by TensorFlow -- confirm how the regression task is meant to
            # be fed before relying on this branch.
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph, depth=num_labels, dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
            loss = tf.reduce_mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric maps a metric name to a (value tensor, update op) pair
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels, task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.compat.v1.train.Saver(tvars)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric:
        var_metric.extend(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.compat.v1.variables_initializer(var_list=var_metric)

    # With at most 1000 eval examples everything is fed in one go; otherwise
    # the examples are fed batch by batch.
    load_all_at_once = num_actual_eval_examples <= 1000

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        if load_all_at_once:
            # The inputs do not depend on the mask, so build them only once.
            (eval_input_ids, eval_input_mask, eval_segment_ids,
             eval_label_ids, _is_real_example) = generate_ph_input(
                 batch_size=num_actual_eval_examples,
                 seq_length=seq_length,
                 examples=eval_examples,
                 label_list=label_list,
                 tokenizer=tokenizer)

        # Walk the ranking once; the two mask lists grow by one coordinate per
        # step so step k masks the first k ranked heads.
        head_mask = []
        layer_mask = []
        for idx, (cur_l, cur_h) in enumerate(importance_coordinates):
            layer_mask.append(cur_l)
            head_mask.append(cur_h)

            # Reset the running metric state before every masked pass.
            sess.run(metric_vars_initializer)

            if load_all_at_once:
                sess.run(metric_op, feed_dict={input_ids_ph: eval_input_ids,
                                               input_mask_ph: eval_input_mask,
                                               segment_ids_ph: eval_segment_ids,
                                               label_ids_ph: eval_label_ids,
                                               head_mask_ph: head_mask,
                                               layer_mask_ph: layer_mask})
            else:
                # Offsets 0, batch_size, 2 * batch_size, ...: one step per
                # batch, including a final partial one.
                for id_eval in range(0, num_actual_eval_examples, batch_size):
                    (eval_input_ids, eval_input_mask, eval_segment_ids,
                     eval_label_ids, _is_real_example) = generate_ph_input(
                         batch_size=batch_size,
                         seq_length=seq_length,
                         examples=eval_examples,
                         label_list=label_list,
                         tokenizer=tokenizer,
                         train_idx_offset=id_eval)
                    sess.run(metric_op, feed_dict={input_ids_ph: eval_input_ids,
                                                   input_mask_ph: eval_input_mask,
                                                   segment_ids_ph: eval_segment_ids,
                                                   label_ids_ph: eval_label_ids,
                                                   head_mask_ph: head_mask,
                                                   layer_mask_ph: layer_mask})

            eval_metric_val = sess.run(metric_val)

            # Only the metric named 'accuracy' is recorded; cells stay 0 when
            # metric_fn does not report one for the task.
            for name, val in zip(metric_name, eval_metric_val):
                if name == 'accuracy':
                    output[cur_l][cur_h] = val
                    print("Mask out the {}th head in (Layer {}, Head {}) | {}: {}"
                          .format(idx + 1, cur_l, cur_h, name, val))

        # os.path.join keeps the file inside output_dir whether or not the
        # flag value ends with a path separator.
        joblib.dump(output, os.path.join(folder, save_file))
Beispiel #3
0
def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and \
       not FLAGS.do_eval and \
       not FLAGS.do_pred:
        raise ValueError(
            "At least one of 'do_train', 'do_eval' or 'do_pred' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)
    if FLAGS.do_pred:
        test_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_test_examples = len(test_examples)
        print("num_actual_test_examples", num_actual_test_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[
                                                None,
                                            ],
                                            name='label_ids')

    tf.compat.v1.logging.info("Running swapping layers!")

    num_train_steps = num_actual_train_examples // batch_size * epochs
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Swapped layers
    layer_folder_name = FLAGS.layers
    if FLAGS.layers is None:
        raise ValueError("In swapping experiments, layers must not be None. ")
    layers = list(map(int, FLAGS.layers.split(',')))
    layer_order = list(range(12)) if embed_dim == 768 else list(range(6))
    layer1, layer2 = layers[0], layers[1]
    if len(layers) == 2:
        layers.sort()
        if layer1 != layer2:
            layer_order[layer1], layer_order[layer2] = layer_order[
                layer2], layer_order[layer1]
        else:
            raise ValueError("Two layers should be different! ")
    else:
        layer_order = layer_order[:layers[0]] + layers[::-1] + layer_order[
            layers[-1] + 1:]
    print("Current layer order: ", layer_order)

    # this placeholder is to control the flag for the dropout
    keep_prob_ph = tf.compat.v1.placeholder(tf.float32, name="keep_prob")
    is_training_ph = tf.compat.v1.placeholder(tf.bool, name='is_training')

    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training_ph,
        input_ids=input_ids_ph,  # input_ids,
        input_mask=input_mask_ph,  # input_mask,
        token_type_ids=segment_ids_ph,  # segment_ids,
        use_one_hot_embeddings=False,
        layer_order=layer_order,
        use_estimator=False)

    output_layer = model.get_pooled_output()
    output_layer = tf.nn.dropout(output_layer, keep_prob=keep_prob_ph)
    output_weights = tf.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric is tf.metric object, (val, op)
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    metric_phs = [
        tf.compat.v1.placeholder(tf.float32, name="{}_ph".format(key))
        for key in metric.keys()
    ]
    summaries = [
        tf.compat.v1.summary.scalar(key, metric_phs[i])
        for i, key in enumerate(metric.keys())
    ]
    train_summary_total = tf.summary.merge(summaries)
    eval_summary_total = tf.summary.merge(summaries)

    log_dir = FLAGS.output_dir + 'layer_{}/'.format(layer_folder_name)

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [
        v for v in tvars
        if 'output_weights' not in v.name and 'output_bias' not in v.name
    ]
    var_output = [
        v for v in tvars
        if 'output_weights' in v.name or "output_bias" in v.name
    ]

    if not FLAGS.load_from_finetuned:
        # Init from Model0
        saver_init = tf.train.Saver(var_init)
    else:
        # Init from Model1
        saver_init = tf.train.Saver(var_init + var_output)

    var_train = var_init + var_output
    print("Training parameters")
    for v in var_train:
        print(v)

    train_op = optimization.create_optimizer(loss=loss,
                                             init_lr=FLAGS.learning_rate,
                                             num_train_steps=num_train_steps,
                                             num_warmup_steps=num_warmup_steps,
                                             use_tpu=False,
                                             tvars=var_train)

    saver_all = tf.train.Saver(var_list=var_train, max_to_keep=1)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.variables_initializer(var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(log_dir + 'log/train/',
                                                 sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(log_dir + 'log/eval/')

        # if number of eval examples < 1000, just load it directly, or load by batch.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
            eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=num_actual_eval_examples,
                                                                     seq_length=seq_length,
                                                                     examples=eval_examples,
                                                                     label_list=label_list,
                                                                     tokenizer=tokenizer)

        start_metric = {"eval_{}".format(key): 0 for key in metric_name}
        if FLAGS.do_train:
            tf.logging.info("***** Run training *****")
            step = 1
            for n in range(epochs):

                np.random.shuffle(train_examples)
                num_batch = num_actual_train_examples // batch_size if num_actual_train_examples % batch_size == 0 \
                    else num_actual_train_examples // batch_size + 1
                id = 0

                for b in range(num_batch):

                    input_ids, input_mask, \
                    segment_ids, label_ids, is_real_example = generate_ph_input(batch_size=batch_size,
                                                                                seq_length=seq_length,
                                                                                examples=train_examples,
                                                                                label_list=label_list,
                                                                                tokenizer=tokenizer,
                                                                                train_idx_offset=id)
                    id += batch_size

                    sess.run(metric_vars_initializer)
                    sess.run([train_op] + metric_op,
                             feed_dict={
                                 input_ids_ph: input_ids,
                                 input_mask_ph: input_mask,
                                 segment_ids_ph: segment_ids,
                                 label_ids_ph: label_ids,
                                 is_training_ph: True,
                                 keep_prob_ph: 0.9
                             })
                    train_metric_val = sess.run(metric_val)
                    train_summary_str = sess.run(
                        train_summary_total,
                        feed_dict={
                            ph: value
                            for ph, value in zip(metric_phs, train_metric_val)
                        })
                    writer.add_summary(train_summary_str, step)

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        # evaluate on dev set

                        if num_actual_eval_examples <= 1000:

                            sess.run(metric_vars_initializer)
                            sess.run(metric_op,
                                     feed_dict={
                                         input_ids_ph: eval_input_ids,
                                         input_mask_ph: eval_input_mask,
                                         segment_ids_ph: eval_segment_ids,
                                         label_ids_ph: eval_label_ids,
                                         is_training_ph: False,
                                         keep_prob_ph: 1
                                     })
                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={
                                    ph: value
                                    for ph, value in zip(
                                        metric_phs, eval_metric_val)
                                })

                        else:
                            num_batch_eval = num_actual_eval_examples // batch_size \
                                if num_actual_eval_examples % batch_size == 0 \
                                else num_actual_eval_examples // batch_size + 1
                            id_eval = 0

                            sess.run(metric_vars_initializer)
                            for _ in range(num_batch_eval):
                                eval_input_ids, eval_input_mask, eval_segment_ids, \
                                eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                                         seq_length=seq_length,
                                                                                         examples=eval_examples,
                                                                                         label_list=label_list,
                                                                                         tokenizer=tokenizer,
                                                                                         train_idx_offset=id_eval)
                                id_eval += batch_size
                                sess.run(metric_op,
                                         feed_dict={
                                             input_ids_ph: eval_input_ids,
                                             input_mask_ph: eval_input_mask,
                                             segment_ids_ph: eval_segment_ids,
                                             label_ids_ph: eval_label_ids,
                                             is_training_ph: False,
                                             keep_prob_ph: 1
                                         })

                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={
                                    ph: value
                                    for ph, value in zip(
                                        metric_phs, eval_metric_val)
                                })

                        writer_eval.add_summary(eval_summary_str, step)

                        if step == 1:
                            for key, val in zip(metric_name, eval_metric_val):
                                start_metric["eval_{}".format(key)] = val

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        train_metric_list = []
                        for i in range(len(train_metric_val)):
                            if metric_name[i] == 'loss':
                                train_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                            else:
                                train_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                        train_str = 'Train ' + '|'.join(train_metric_list)

                        eval_metric_list = []
                        for i in range(len(eval_metric_val)):
                            if metric_name[i] == 'loss':
                                eval_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                            else:
                                eval_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                        eval_str = 'Eval ' + '|'.join(eval_metric_list)

                        print(
                            "Swap layer order {} | Epoch: %4d/%4d | Batch: %4d/%4d | {} | {}"
                            .format(layer_folder_name, train_str, eval_str) %
                            (n, epochs, b, num_batch))

                    if step % num_batch == 0:
                        saver_all.save(sess,
                                       log_dir +
                                       'swap_{}'.format(layer_folder_name),
                                       global_step=step)

                    step += 1

            writer.close()
            writer_eval.close()

        end_metric = {"eval_{}".format(key): 0 for key in metric_name}
        if FLAGS.do_eval:
            tf.logging.info("***** Run evaluation *****")
            if num_actual_eval_examples <= 1000:

                sess.run(metric_vars_initializer)
                sess.run(metric_op,
                         feed_dict={
                             input_ids_ph: eval_input_ids,
                             input_mask_ph: eval_input_mask,
                             segment_ids_ph: eval_segment_ids,
                             label_ids_ph: eval_label_ids,
                             is_training_ph: False,
                             keep_prob_ph: 1
                         })
                eval_metric_val = sess.run(metric_val)
                preds = sess.run(predictions,
                                 feed_dict={
                                     input_ids_ph: eval_input_ids,
                                     input_mask_ph: eval_input_mask,
                                     segment_ids_ph: eval_segment_ids,
                                     label_ids_ph: eval_label_ids,
                                     is_training_ph: False,
                                     keep_prob_ph: 1
                                 })
                eval_label_ids_lst = eval_label_ids
            else:
                num_batch_eval = num_actual_eval_examples // batch_size \
                    if num_actual_eval_examples % batch_size == 0 \
                    else num_actual_eval_examples // batch_size + 1
                id_eval = 0
                preds = np.zeros(num_actual_eval_examples)
                eval_label_ids_lst = np.zeros(num_actual_eval_examples)

                sess.run(metric_vars_initializer)
                for i in range(num_batch_eval):
                    eval_input_ids, eval_input_mask, eval_segment_ids, \
                    eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                             seq_length=seq_length,
                                                                             examples=eval_examples,
                                                                             label_list=label_list,
                                                                             tokenizer=tokenizer,
                                                                             train_idx_offset=id_eval)
                    id_eval += batch_size
                    sess.run(metric_op,
                             feed_dict={
                                 input_ids_ph: eval_input_ids,
                                 input_mask_ph: eval_input_mask,
                                 segment_ids_ph: eval_segment_ids,
                                 label_ids_ph: eval_label_ids,
                                 is_training_ph: False,
                                 keep_prob_ph: 1
                             })
                    pred = sess.run(predictions,
                                    feed_dict={
                                        input_ids_ph: eval_input_ids,
                                        input_mask_ph: eval_input_mask,
                                        segment_ids_ph: eval_segment_ids,
                                        label_ids_ph: eval_label_ids,
                                        is_training_ph: False,
                                        keep_prob_ph: 1
                                    })
                    preds[i * batch_size:min(id_eval, num_actual_eval_examples
                                             )] = pred[:]
                    eval_label_ids_lst[i * batch_size:min(
                        id_eval, num_actual_eval_examples)] = eval_label_ids[:]
                eval_metric_val = sess.run(metric_val)

            for key, val in zip(metric_name, eval_metric_val):
                end_metric["eval_{}".format(key)] = val

            output_predict_file = os.path.join(log_dir, 'dev_predictions.tsv')
            writer_output = tf.io.gfile.GFile(output_predict_file, "w")
            preds = preds.astype(int)
            eval_label_ids_lst = eval_label_ids_lst.astype(int)

            num_written_lines = 0
            if task_name != 'stsb':
                writer_output.write(
                    "ID \t Label ID \t Label \t Ground Truth ID \t Ground Truth \n"
                )
            else:
                writer_output.write("ID \t Label \n")
            for (i, pred) in enumerate(preds):
                if task_name != 'stsb':
                    writer_output.write("{} \t {} \t {} \t {} \t {} \n".format(
                        i, pred, label_list[pred], eval_label_ids_lst[i],
                        label_list[eval_label_ids_lst[i]]))
                else:
                    writer_output.write("{} \t {} \n".format(
                        num_written_lines, pred))
            writer_output.close()
            tf.logging.info("***** Finish writing *****")

        print("Start metric", start_metric)
        print("End metric", end_metric)

        test_metric = {"test_{}".format(key): 0 for key in metric_name}
        if FLAGS.do_pred:
            # if number of test examples < 1000, just load it directly, or load by batch.
            # prediction
            tf.logging.info("***** Predict results *****")
            if num_actual_test_examples <= 1000:
                test_input_ids, test_input_mask, test_segment_ids, \
                test_label_ids, test_is_real_example = generate_ph_input(batch_size=num_actual_test_examples,
                                                                         seq_length=seq_length,
                                                                         examples=test_examples,
                                                                         label_list=label_list,
                                                                         tokenizer=tokenizer)
                sess.run(metric_vars_initializer)
                sess.run(metric_op,
                         feed_dict={
                             input_ids_ph: test_input_ids,
                             input_mask_ph: test_input_mask,
                             segment_ids_ph: test_segment_ids,
                             label_ids_ph: test_label_ids,
                             is_training_ph: False,
                             keep_prob_ph: 1
                         })
                test_metric_val = sess.run(metric_val)
                preds = sess.run(predictions,
                                 feed_dict={
                                     input_ids_ph: test_input_ids,
                                     input_mask_ph: test_input_mask,
                                     segment_ids_ph: test_segment_ids,
                                     label_ids_ph: test_label_ids,
                                     is_training_ph: False,
                                     keep_prob_ph: 1
                                 })
                test_label_ids_lst = test_label_ids
            else:
                num_batch_test = num_actual_test_examples // batch_size \
                    if num_actual_test_examples % batch_size == 0 \
                    else num_actual_test_examples // batch_size + 1
                id_test = 0

                preds = np.zeros(num_actual_test_examples)
                test_label_ids_lst = np.zeros(num_actual_test_examples)
                sess.run(metric_vars_initializer)
                for i in range(num_batch_test):
                    test_input_ids, test_input_mask, test_segment_ids, \
                    test_label_ids, test_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                             seq_length=seq_length,
                                                                             examples=test_examples,
                                                                             label_list=label_list,
                                                                             tokenizer=tokenizer,
                                                                             train_idx_offset=id_test)
                    id_test += batch_size
                    sess.run(metric_op,
                             feed_dict={
                                 input_ids_ph: test_input_ids,
                                 input_mask_ph: test_input_mask,
                                 segment_ids_ph: test_segment_ids,
                                 label_ids_ph: test_label_ids,
                                 is_training_ph: False,
                                 keep_prob_ph: 1
                             })
                    pred = sess.run(predictions,
                                    feed_dict={
                                        input_ids_ph: test_input_ids,
                                        input_mask_ph: test_input_mask,
                                        segment_ids_ph: test_segment_ids,
                                        label_ids_ph: test_label_ids,
                                        is_training_ph: False,
                                        keep_prob_ph: 1
                                    })
                    preds[i * batch_size:min(id_test, num_actual_test_examples
                                             )] = pred[:]
                    test_label_ids_lst[i * batch_size:min(
                        id_test, num_actual_test_examples)] = test_label_ids[:]
                test_metric_val = sess.run(metric_val)
            for key, val in zip(metric_name, test_metric_val):
                test_metric["test_{}".format(key)] = val

            output_predict_file = os.path.join(log_dir, 'test_predictions.tsv')
            submit_predict_file = os.path.join(
                log_dir, "{}.tsv".format(standard_file_name[task_name]))
            writer_output = tf.io.gfile.GFile(output_predict_file, "w")
            writer_submit = tf.io.gfile.GFile(submit_predict_file, 'w')
            preds = preds.astype(int)
            test_label_ids_lst = test_label_ids_lst.astype(int)

            num_written_lines = 0
            if task_name != 'stsb':
                writer_output.write(
                    "ID \t Label ID \t Label \t Ground Truth ID \t Ground Truth \n"
                )
            else:
                writer_output.write("ID \t Label \n")
            writer_submit.write("ID \t Label \n")
            for (i, pred) in enumerate(preds):
                if task_name != 'stsb':
                    writer_output.write("{} \t {} \t {} \t {} \t {} \n".format(
                        i, pred, label_list[pred], test_label_ids_lst[i],
                        label_list[test_label_ids_lst[i]]))
                    writer_submit.write("{} \t {} \n".format(
                        i, label_list[pred]))
                else:
                    writer_output.write("{} \t {} \n".format(
                        num_written_lines, pred))
                    writer_submit.write("{} \t {} \n".format(i, pred))
            writer_output.close()
            writer_submit.close()
            tf.logging.info("***** Finish writing *****")

        with tf.io.gfile.GFile(FLAGS.output_dir + 'results.txt',
                               'a') as writer:
            eval_start, eval_end, test_end = [], [], []
            for metric in metric_name:
                if metric != 'loss':
                    eval_start.append("{}: %.4f".format(metric) %
                                      start_metric["eval_{}".format(metric)])
                    eval_end.append("{}: %.4f".format(metric) %
                                    end_metric["eval_{}".format(metric)])
                    test_end.append("{}: %.4f".format(metric) %
                                    test_metric["test_{}".format(metric)])

            writer.write(
                "Swap layer order {}: Dev start: {} | Dev end: {} | Test end: {}\n"
                .format(layer_folder_name, ','.join(eval_start),
                        ','.join(eval_end), ','.join(test_end)))
Beispiel #4
0
def main(_):
    """Dump intermediate token embeddings of a BERT model to a pickle file.

    Builds a ``modeling.BertModel`` graph on placeholder inputs, restores the
    trainable variables from ``FLAGS.init_checkpoint``, runs one forward pass
    and stores the fetched activations with ``joblib.dump`` in
    ``<FLAGS.output_dir>/track_embeddings.pickle``.

    The input is either every dev example of ``FLAGS.task_name`` (fed as a
    single batch) or, when ``FLAGS.feed_ones`` is set, one synthetic example
    made of all-zero token ids, an all-one input mask and all-zero segment
    ids.

    The dumped dict contains:
      * "final_embeddings": array of shape
        (num_examples, max_seq_length, hidden_size);
      * "layer_self_attention_ff", "layer_attention", "layer_mlp",
        "layer_output": arrays of shape
        (n_layers, num_examples, max_seq_length, hidden_size);
      * "sentence": the tokens returned by ``generate_ph_input`` (or a single
        row of empty strings when ``FLAGS.feed_ones`` is set).

    Args:
        _: unused positional argument.

    Raises:
        ValueError: if ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings`` or ``FLAGS.task_name`` is unknown.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor,
        "semeval": SemEvalProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    n_layers = FLAGS.n_layers
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')

    tf.compat.v1.logging.info(
        "Running tracking hidden token embeddings through the model!")

    # NOTE: a seeded random subsample of 1000 dev examples used to be taken
    # here; it is currently disabled, so every dev example is used unless
    # FLAGS.feed_ones reduces the run to a single synthetic example.
    if FLAGS.feed_ones:
        num_actual_eval_examples = 1
    print("here we only take {} samples.".format(num_actual_eval_examples))
    tf.compat.v1.logging.info(
        "For faster calculation, we reduce the example size! ")

    save_file = 'track_embeddings.pickle'

    # Shapes of one activation tensor and of the per-layer stack of them.
    example_shape = (num_actual_eval_examples, seq_length, embed_dim)
    layer_shape = (n_layers,) + example_shape

    # Only a subset of the activations the model exposes is tracked. The
    # per-layer buffers are pre-allocated and filled layer by layer below.
    output = {
        "final_embeddings": np.zeros(example_shape),
        "layer_self_attention_ff": np.zeros(layer_shape),
        "layer_attention": np.zeros(layer_shape),
        "layer_mlp": np.zeros(layer_shape),
        "layer_output": np.zeros(layer_shape),
        "sentence": []
    }

    # Both layer-selection flags are comma-separated lists of layer indices.
    layers_cancel_skip_connection = []
    if FLAGS.layers_cancel_skip_connection is not None:
        layers_cancel_skip_connection = sorted(
            int(idx) for idx in FLAGS.layers_cancel_skip_connection.split(','))
        print("Layers need to cancel skip-connection: ",
              layers_cancel_skip_connection)

    layers_use_relu = []
    if FLAGS.layers_use_relu is not None:
        layers_use_relu = sorted(
            int(idx) for idx in FLAGS.layers_use_relu.split(','))
        print("Layers need to use ReLU: ", layers_use_relu)

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,
        input_mask=input_mask_ph,
        token_type_ids=segment_ids_ph,
        use_one_hot_embeddings=False,
        cancel_skip_connection=layers_cancel_skip_connection,
        layer_use_relu=layers_use_relu,
        feed_same=FLAGS.feed_same)

    # Extract the intermediate outputs that are tracked. Further getters that
    # were fetched by earlier versions of this script and are disabled here:
    # get_in_embeds()[0], get_all_head_output(),
    # get_all_attention_before_layernorm(),
    # get_all_layer_tokens_after_FFGeLU(), get_all_ffn_before_layernorm().
    final_embeddings = model.get_embedding_output()
    layer_self_attention_ff = model.get_all_attention_before_dropout()
    layer_attention = model.get_all_layer_tokens_beforeMLP()
    layer_mlp = model.get_all_layer_tokens_afterMLP()
    layer_output = model.get_all_encoder_layers()

    # load weights
    # Use the compat.v1 aliases like the rest of this function:
    # tf.train.Saver / tf.global_variables_initializer are not available in
    # the top-level TF 2.x namespace.
    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.compat.v1.train.Saver(tvars)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        input_tokens = None
        if FLAGS.feed_ones:
            # One synthetic example; explicit int32 matches the placeholders
            # instead of relying on an implicit float64 -> int32 conversion.
            input_ids = np.zeros([1, seq_length], dtype=np.int32)
            input_mask = np.ones([1, seq_length], dtype=np.int32)
            segment_ids = np.zeros([1, seq_length], dtype=np.int32)
        else:
            # The whole dev set is converted and fed as a single batch.
            (input_ids, input_mask, segment_ids, _label_ids,
             _is_real_example, input_tokens) = generate_ph_input(
                 batch_size=num_actual_eval_examples,
                 seq_length=seq_length,
                 examples=eval_examples,
                 label_list=label_list,
                 tokenizer=tokenizer,
                 return_tokens=True)

        (final_embeddings_val, layer_self_attention_ff_val,
         layer_attention_val, layer_mlp_val, layer_output_val) = sess.run(
             [
                 final_embeddings, layer_self_attention_ff, layer_attention,
                 layer_mlp, layer_output
             ],
             feed_dict={
                 input_ids_ph: input_ids,
                 input_mask_ph: input_mask,
                 segment_ids_ph: segment_ids
             })

        # assign values to output dict
        output["final_embeddings"] = np.reshape(final_embeddings_val,
                                                example_shape)

        # Fetched per-layer values, keyed by the output entry they fill.
        per_layer_vals = {
            "layer_self_attention_ff": layer_self_attention_ff_val,
            "layer_attention": layer_attention_val,
            "layer_mlp": layer_mlp_val,
            "layer_output": layer_output_val,
        }
        for layer in tqdm(range(n_layers)):
            for key, layer_vals in per_layer_vals.items():
                output[key][layer] = np.reshape(layer_vals[layer],
                                                example_shape)

        if FLAGS.feed_ones:
            # No real tokens exist for the synthetic example.
            output["sentence"] = [[''] * seq_length]
        else:
            output["sentence"] = input_tokens

        joblib.dump(output, os.path.join(FLAGS.output_dir, save_file))
Beispiel #5
0
def main(_):
    """Train an approximator for one part of a single BERT encoder layer.

    The BERT graph is built from ``FLAGS.bert_config_file`` and an
    input/target tensor pair ``(x, y)`` is selected according to
    ``FLAGS.approximate_part``. Only trainable variables whose name contains
    ``"non-linearity"`` are handed to the optimizer; every other trainable
    variable is restored from ``FLAGS.init_checkpoint``.

    Side effects:
        * Scalar summaries are written below
          ``<output_dir>layer_<layers>/log/{train,eval}/``.
        * The approximator variables are checkpointed whenever the global
          step is a multiple of the number of batches per epoch.
        * One line with per-token cosine similarities (actual tokens,
          padding tokens, all tokens) is appended to
          ``<output_dir>results.txt``.

    Note:
        Output paths are built by plain string concatenation, so
        ``FLAGS.output_dir`` has to end with a path separator.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings``, the task is unknown,
            ``FLAGS.layers`` is missing or names more than one layer,
            ``FLAGS.approximate_part`` is not recognised, or ``FLAGS.loss``
            is not one of ``'cosine'``, ``'mse'``, ``'l2'``.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')

    # train an independent model to approximate the selected part of a layer
    tf.compat.v1.logging.info("Training Approximator!")

    # get the layer which needs replacement (exactly one is allowed here)
    if FLAGS.layers is None:
        raise ValueError(
            "In training non-linearity experiments, layers must not be None. ")
    layer_folder_name = FLAGS.layers
    layers = sorted(int(layer) for layer in FLAGS.layers.split(','))
    if len(layers) != 1:
        raise ValueError("Here it allows only one single layer. ")
    approximated_layer = layers[0]
    print("Current approximated layer: ", approximated_layer)
    approximator_setting = FLAGS.approximator_setting

    layers_cancel_skip_connection = []
    if FLAGS.layers_cancel_skip_connection is not None:
        layers_cancel_skip_connection = sorted(
            int(layer)
            for layer in FLAGS.layers_cancel_skip_connection.split(','))
        print("Layers need to cancel skip-connection: ",
              layers_cancel_skip_connection)

    layers_use_relu = []
    if FLAGS.layers_use_relu is not None:
        layers_use_relu = sorted(
            int(layer) for layer in FLAGS.layers_use_relu.split(','))
        print("Layers need to use ReLU: ", layers_use_relu)

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,  # input_ids,
        input_mask=input_mask_ph,  # input_mask,
        token_type_ids=segment_ids_ph,  # segment_ids,
        use_one_hot_embeddings=False,
        approximator_setting=approximator_setting,
        cancel_skip_connection=layers_cancel_skip_connection,
        layer_use_relu=layers_use_relu)

    def encoder_layer_input():
        """Return the tensor that feeds the approximated encoder layer.

        This is the previous encoder layer's output, or the embedding
        output when the approximated layer is layer 0.
        """
        if approximated_layer > 0:
            return model.get_all_encoder_layers()[approximated_layer - 1]
        return model.get_embedding_output()

    # define the input and output according to the approximated part
    approximate_part = FLAGS.approximate_part
    if approximate_part == 'mlp':
        # only FFGeLU and FF in FFN, without dropout and layernorm
        x = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
        y = model.get_all_layer_tokens_afterMLP()[approximated_layer]
    elif approximate_part == 'ffgelu':
        # only FFGeLU in FFN
        x = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
        y = model.get_all_layer_tokens_after_FFGeLU()[approximated_layer]
    elif approximate_part == 'self_attention':
        # only self-attention part, without linear layer, dropout and layernorm
        x = encoder_layer_input()
        y = model.get_all_head_output()[approximated_layer]
    elif approximate_part == 'self_attention_ff':
        # only self-attention + linear part, without dropout and layernorm
        x = encoder_layer_input()
        y = model.get_all_attention_before_dropout()[approximated_layer]
    elif approximate_part == 'ff_after_self_attention':
        # only the linear layer after self-attention
        x = model.get_all_head_output()[approximated_layer]
        y = model.get_all_attention_before_dropout()[approximated_layer]
    elif approximate_part == 'attention_before_ln':
        # only self-attention + linear part, before layernorm
        x = encoder_layer_input()
        y = model.get_all_attention_before_layernorm()[approximated_layer]
    elif approximate_part == 'attention':
        # whole attention block including dropout and layernorm
        x = encoder_layer_input()
        y = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
    elif approximate_part == 'ffn':
        # whole FFN block including dropout and layernorm
        x = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
        y = model.get_all_encoder_layers()[approximated_layer]
    elif approximate_part == 'ffn_before_ln':
        # only FFGeLU and FF in FFN, before layernorm
        x = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
        y = model.get_all_ffn_before_layernorm()[approximated_layer]
    elif approximate_part == 'encoder':
        # whole encoder including dropout and layernorm
        x = encoder_layer_input()
        y = model.get_all_encoder_layers()[approximated_layer]
    else:
        raise ValueError("Need to specify correct value. ")

    # Every listed setting targets a hidden-size output; anything else
    # (HS*4_FFGeLU) targets the 4x wider intermediate FFN activation.
    hidden_size_settings = (
        'HS_MLP', 'HS*4+HS_MLP', 'HS_Self_Attention',
        'HS_Self_Attention_FF', 'HS_FFN', 'HS_Attention', 'HS_Encoder',
        'HS_Attention_Before_LN', 'HS_FFN_Before_LN',
        'HS_FF_After_Self_Attention'
    )
    if FLAGS.approximator_setting in hidden_size_settings:
        approximator_dim = embed_dim
    else:  # HS*4_FFGeLU
        approximator_dim = embed_dim * 4

    x = tf.reshape(x, [-1, seq_length, embed_dim])
    y = tf.reshape(y, [-1, seq_length, approximator_dim])

    # Build the approximator. Its variables are later picked out by the
    # substring "non-linearity" in their name, so the approximator module is
    # expected to create them under such a scope.
    if not FLAGS.use_nonlinear_approximator:
        y_pred = approximator.linear_approximator(
            input=x,
            approximated_layer=approximated_layer,
            hidden_size=approximator_dim)
    else:
        y_pred = approximator.nonlinear_approximator(
            input=x,
            approximated_layer=approximated_layer,
            hidden_size=approximator_dim,
            num_layer=1,
            use_dropout=FLAGS.use_dropout,
            dropout_p=0.2)

    # NOTE (original author): this is the alias of cosine_proximity, [0, 1],
    # not the TF2.0 version, which is [-1, 0].
    # NOTE(review): the sign convention of this function differs between
    # TensorFlow releases - confirm against the installed version.
    similarity_token = tf.keras.losses.cosine_similarity(y, y_pred, axis=-1)
    l2_distance_token = tf.norm(y - y_pred, ord=2, axis=-1)

    # loss: similarity or mean squared error or l2 norm
    # when using cosine similarity, the target becomes minimizing to 0.
    if FLAGS.loss == 'cosine':
        per_example_loss = 1. - tf.reduce_mean(similarity_token, axis=-1)
    elif FLAGS.loss == 'mse':
        per_example_loss = tf.reduce_mean(tf.keras.losses.mse(y, y_pred),
                                          axis=-1)
    elif FLAGS.loss == 'l2':
        per_example_loss = tf.reduce_mean(l2_distance_token, axis=-1)
    else:
        # A constant fallback loss cannot be optimized; fail with a clear
        # message instead of an obscure graph-construction error.
        raise ValueError(
            "Unsupported loss %r, expected one of 'cosine', 'mse', 'l2'." %
            (FLAGS.loss, ))

    loss = tf.reduce_mean(per_example_loss, axis=-1)
    similarity = tf.reduce_mean(tf.reduce_mean(similarity_token, axis=-1),
                                axis=-1)
    l2_distance = tf.reduce_mean(tf.reduce_mean(l2_distance_token, axis=-1),
                                 axis=-1)

    # for summary: values are computed first and fed back through
    # placeholders, so the same summary ops serve training and evaluation.
    loss_ph = tf.compat.v1.placeholder(tf.float32, name='loss')
    similarity_ph = tf.compat.v1.placeholder(tf.float32, name='similarity')
    l2_distance_ph = tf.compat.v1.placeholder(tf.float32, name='l2_distance')
    loss_sum = tf.summary.scalar("loss", loss_ph)
    similarity_sum = tf.summary.scalar("similarity", similarity_ph)
    l2_distance_sum = tf.summary.scalar("l2_distance", l2_distance_ph)
    train_summary_total = tf.summary.merge(
        [loss_sum, similarity_sum, l2_distance_sum])
    eval_summary_total = tf.summary.merge(
        [loss_sum, similarity_sum, l2_distance_sum])

    # NOTE: plain concatenation, FLAGS.output_dir must end with a separator.
    log_dir = FLAGS.output_dir + 'layer_{}/'.format(layer_folder_name)

    # define optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

    # load weights: everything except the approximator comes from the
    # initial checkpoint
    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [var for var in tvars if "non-linearity" not in var.name]
    saver_init = tf.train.Saver(var_init)

    # take variables related to approximators
    var_approximator = [var for var in tvars if "non-linearity" in var.name]
    print("Training variables")
    for var in var_approximator:
        print(var)

    # define optimizer op and saver; only approximator variables are updated
    optimizer_op = optimizer.minimize(loss=loss, var_list=var_approximator)
    saver_approximator = tf.train.Saver(var_list=var_approximator,
                                        max_to_keep=1)

    # Number of training batches per epoch (ceiling division); it does not
    # change between epochs, so compute it once.
    num_batch = num_actual_train_examples // batch_size
    if num_actual_train_examples % batch_size != 0:
        num_batch += 1

    # add this GPU settings
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(log_dir + 'log/train/',
                                                 sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(log_dir + 'log/eval/')

        # if number of eval examples <= 1000, just load it directly, or load by batch.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
                eval_label_ids, eval_is_real_example, \
                eval_input_tokens = generate_ph_input(
                    batch_size=num_actual_eval_examples,
                    seq_length=seq_length,
                    examples=eval_examples,
                    label_list=label_list,
                    tokenizer=tokenizer,
                    return_tokens=True)

        step = 1
        for epoch in range(epochs):

            np.random.shuffle(train_examples)
            train_offset = 0

            for batch_idx in range(num_batch):

                input_ids, input_mask, segment_ids, _, _ = generate_ph_input(
                    batch_size=batch_size,
                    seq_length=seq_length,
                    examples=train_examples,
                    label_list=label_list,
                    tokenizer=tokenizer,
                    train_idx_offset=train_offset)
                train_offset += batch_size

                train_loss, train_similarity, train_l2_distance, _ = sess.run(
                    [loss, similarity, l2_distance, optimizer_op],
                    feed_dict={
                        input_ids_ph: input_ids,
                        input_mask_ph: input_mask,
                        segment_ids_ph: segment_ids
                    })
                train_summary_str = sess.run(train_summary_total,
                                             feed_dict={
                                                 loss_ph: train_loss,
                                                 similarity_ph:
                                                 train_similarity,
                                                 l2_distance_ph:
                                                 train_l2_distance
                                             })
                writer.add_summary(train_summary_str, step)

                # evaluate on dev set and report: on the first step, every
                # 100 steps, and at the end of every epoch
                if step % 100 == 0 or step % num_batch == 0 or step == 1:

                    if num_actual_eval_examples <= 1000:
                        # the whole dev set was loaded once before training
                        eval_loss, eval_similarity, eval_l2_distance = sess.run(
                            [loss, similarity, l2_distance],
                            feed_dict={
                                input_ids_ph: eval_input_ids,
                                input_mask_ph: eval_input_mask,
                                segment_ids_ph: eval_segment_ids
                            })
                    else:
                        # large dev set: evaluate batch by batch and average
                        # the per-batch metrics
                        num_batch_eval = num_actual_eval_examples // batch_size
                        if num_actual_eval_examples % batch_size != 0:
                            num_batch_eval += 1
                        eval_offset = 0
                        acc_loss, acc_similarity, acc_l2_distance = 0, 0, 0

                        for _ in range(num_batch_eval):
                            eval_input_ids, eval_input_mask, eval_segment_ids, \
                                eval_label_ids, eval_is_real_example = generate_ph_input(
                                    batch_size=batch_size,
                                    seq_length=seq_length,
                                    examples=eval_examples,
                                    label_list=label_list,
                                    tokenizer=tokenizer,
                                    train_idx_offset=eval_offset)
                            eval_offset += batch_size

                            eval_loss, eval_similarity, eval_l2_distance = sess.run(
                                [loss, similarity, l2_distance],
                                feed_dict={
                                    input_ids_ph: eval_input_ids,
                                    input_mask_ph: eval_input_mask,
                                    segment_ids_ph: eval_segment_ids
                                })
                            acc_loss += eval_loss
                            acc_similarity += eval_similarity
                            acc_l2_distance += eval_l2_distance

                        eval_loss = acc_loss / num_batch_eval
                        eval_similarity = acc_similarity / num_batch_eval
                        eval_l2_distance = acc_l2_distance / num_batch_eval

                    eval_summary_str = sess.run(eval_summary_total,
                                                feed_dict={
                                                    loss_ph:
                                                    eval_loss,
                                                    similarity_ph:
                                                    eval_similarity,
                                                    l2_distance_ph:
                                                    eval_l2_distance
                                                })
                    writer_eval.add_summary(eval_summary_str, step)

                    print(
                        "Approximating layer: %2d | Epoch: %4d/%4d | Batch: %4d/%4d | "
                        "Train loss: %2.4f | Train similarity/L2 distance: %.4f/%2.4f | "
                        "Eval loss: %2.4f | Eval similarity/L2 distance: %.4f/%2.4f"
                        % (approximated_layer, epoch, epochs, batch_idx,
                           num_batch, train_loss, train_similarity,
                           train_l2_distance, eval_loss, eval_similarity,
                           eval_l2_distance))

                # checkpoint the approximator at the end of every epoch
                if step % num_batch == 0:
                    saver_approximator.save(
                        sess,
                        log_dir + 'approximator_{}'.format(layer_folder_name),
                        global_step=step)

                step += 1

        writer.close()
        writer_eval.close()

        # final eval: run the whole dev set in one pass and compare target
        # and prediction token by token
        eval_input_ids, eval_input_mask, eval_segment_ids, \
            eval_label_ids, eval_is_real_example, \
            eval_input_tokens = generate_ph_input(
                batch_size=num_actual_eval_examples,
                seq_length=seq_length,
                examples=eval_examples,
                label_list=label_list,
                tokenizer=tokenizer,
                return_tokens=True)
        eval_token_length = [len(elem) for elem in eval_input_tokens]
        ground_truths, predictions = sess.run(
            [y, y_pred],
            feed_dict={
                input_ids_ph: eval_input_ids,
                input_mask_ph: eval_input_mask,
                segment_ids_ph: eval_segment_ids
            })

        cosine_similarity_actual_token = 0
        cosine_similarity_padding = 0
        cosine_similarity_all = 0
        num_sample_no_padding = 0
        for i in range(num_actual_eval_examples):
            token_length = eval_token_length[i]
            # samples that fill the whole sequence have no padding positions
            # to average over, so they are left out of all three statistics
            if token_length >= seq_length:
                num_sample_no_padding += 1
                continue
            # `cosine` is a distance, so 1 - distance is the similarity
            token_similarity = [
                1 - cosine(ground_truths[i][j], predictions[i][j])
                for j in range(seq_length)
            ]
            cosine_similarity_actual_token += np.mean(
                token_similarity[:token_length])
            cosine_similarity_padding += np.mean(
                token_similarity[token_length:])
            cosine_similarity_all += np.mean(token_similarity)

        num_sample_with_padding = (num_actual_eval_examples -
                                   num_sample_no_padding)
        if num_sample_with_padding > 0:
            cosine_similarity_actual_token /= num_sample_with_padding
            cosine_similarity_padding /= num_sample_with_padding
            cosine_similarity_all /= num_sample_with_padding
        else:
            # every sample was skipped: record NaN instead of dividing by zero
            cosine_similarity_actual_token = float('nan')
            cosine_similarity_padding = float('nan')
            cosine_similarity_all = float('nan')

        print("Skip {} samples without paddings".format(num_sample_no_padding))

        with tf.io.gfile.GFile(FLAGS.output_dir + 'results.txt',
                               'a') as results_file:
            results_file.write(
                "Approximating layer %2d: Actual: %.4f/%.4f | Pad: %.4f/%.4f | All: %.4f/%.4f\n"
                % (approximated_layer, cosine_similarity_actual_token,
                   1 - cosine_similarity_actual_token,
                   cosine_similarity_padding, 1 - cosine_similarity_padding,
                   cosine_similarity_all, 1 - cosine_similarity_all))
Beispiel #6
0
def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train:
        raise ValueError("At least 'do_train' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[
                                                None,
                                            ],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Running leave the most important head per layer then fine-tune!")

    num_train_steps = num_actual_train_examples // batch_size * epochs
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    cur_layer = FLAGS.cur_layer
    most_important_head = FLAGS.most_important_head
    print("Current most important head:", cur_layer, most_important_head)

    # this placeholder is to control the flag for the dropout
    keep_prob_ph = tf.compat.v1.placeholder(tf.float32, name="keep_prob")
    is_training_ph = tf.compat.v1.placeholder(tf.bool, name='is_training')

    # two placeholders for the head coordinates, layer, head
    head_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[
                                                None,
                                            ],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[
                                                 None,
                                             ],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training_ph,
        input_ids=input_ids_ph,  # input_ids,
        input_mask=input_mask_ph,  # input_mask,
        token_type_ids=segment_ids_ph,  # segment_ids,
        use_one_hot_embeddings=False,
        use_estimator=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    output_layer = model.get_pooled_output()
    output_layer = tf.nn.dropout(output_layer, keep_prob=keep_prob_ph)
    output_weights = tf.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric is tf.metric object, (val, op)
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    log_dir = FLAGS.output_dir + 'layer_{}_head_{}/'.format(
        cur_layer, most_important_head)

    metric_phs = [
        tf.compat.v1.placeholder(tf.float32, name="{}_ph".format(key))
        for key in metric.keys()
    ]
    summaries = [
        tf.compat.v1.summary.scalar(key, metric_phs[i])
        for i, key in enumerate(metric.keys())
    ]
    train_summary_total = tf.summary.merge(summaries)
    eval_summary_total = tf.summary.merge(summaries)

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [
        v for v in tvars
        if 'output_weights' not in v.name and 'output_bias' not in v.name
    ]
    var_output = [
        v for v in tvars
        if 'output_weights' in v.name or "output_bias" in v.name
    ]

    if not FLAGS.load_from_finetuned:
        # Init from Model0
        saver_init = tf.train.Saver(var_init)
    else:
        # Init from Model1
        saver_init = tf.train.Saver(var_init + var_output)

    var_train = var_init + var_output
    print("Training parameters")
    for v in var_train:
        print(v)

    train_op = optimization.create_optimizer(loss=loss,
                                             init_lr=FLAGS.learning_rate,
                                             num_train_steps=num_train_steps,
                                             num_warmup_steps=num_warmup_steps,
                                             use_tpu=False,
                                             tvars=var_train)

    saver_all = tf.train.Saver(var_list=var_init + var_output, max_to_keep=1)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.variables_initializer(var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(log_dir + 'log/train/',
                                                 sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(log_dir + 'log/eval/')

        # heads need to be masked out in cur_layer
        head_mask = [head for head in range(12) if head != most_important_head]
        layer_mask = [cur_layer for _ in range(11)]

        # if number of eval examples < 1000, just load it directly, or load by batch.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
            eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=num_actual_eval_examples,
                                                                     seq_length=seq_length,
                                                                     examples=eval_examples,
                                                                     label_list=label_list,
                                                                     tokenizer=tokenizer)

        start_metric = {"eval_{}".format(key): 0 for key in metric_name}
        end_metric = {"eval_{}".format(key): 0 for key in metric_name}

        if FLAGS.do_train:
            tf.logging.info("***** Run training *****")
            step = 1
            for n in range(epochs):

                np.random.shuffle(train_examples)
                num_batch = num_actual_train_examples // batch_size if num_actual_train_examples % batch_size == 0 \
                    else num_actual_train_examples // batch_size + 1
                id = 0

                for b in range(num_batch):

                    input_ids, input_mask, \
                    segment_ids, label_ids, is_real_example = generate_ph_input(batch_size=batch_size,
                                                                                seq_length=seq_length,
                                                                                examples=train_examples,
                                                                                label_list=label_list,
                                                                                tokenizer=tokenizer,
                                                                                train_idx_offset=id)
                    id += batch_size
                    sess.run(metric_vars_initializer)
                    sess.run([train_op] + metric_op,
                             feed_dict={
                                 input_ids_ph: input_ids,
                                 input_mask_ph: input_mask,
                                 segment_ids_ph: segment_ids,
                                 label_ids_ph: label_ids,
                                 is_training_ph: True,
                                 keep_prob_ph: 0.9,
                                 head_mask_ph: head_mask,
                                 layer_mask_ph: layer_mask
                             })
                    train_metric_val = sess.run(metric_val)
                    train_summary_str = sess.run(
                        train_summary_total,
                        feed_dict={
                            ph: value
                            for ph, value in zip(metric_phs, train_metric_val)
                        })
                    writer.add_summary(train_summary_str, step)

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        # evaluate on dev set

                        if num_actual_eval_examples <= 1000:
                            sess.run(metric_vars_initializer)
                            sess.run(metric_op,
                                     feed_dict={
                                         input_ids_ph: eval_input_ids,
                                         input_mask_ph: eval_input_mask,
                                         segment_ids_ph: eval_segment_ids,
                                         label_ids_ph: eval_label_ids,
                                         is_training_ph: False,
                                         keep_prob_ph: 1,
                                         head_mask_ph: head_mask,
                                         layer_mask_ph: layer_mask
                                     })
                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={
                                    ph: value
                                    for ph, value in zip(
                                        metric_phs, eval_metric_val)
                                })
                        else:
                            num_batch_eval = num_actual_eval_examples // batch_size \
                                if num_actual_eval_examples % batch_size == 0 \
                                else num_actual_eval_examples // batch_size + 1
                            id_eval = 0

                            sess.run(metric_vars_initializer)
                            for _ in range(num_batch_eval):
                                eval_input_ids, eval_input_mask, eval_segment_ids, \
                                eval_label_ids, eval_is_real_example = generate_ph_input(batch_size=batch_size,
                                                                                         seq_length=seq_length,
                                                                                         examples=eval_examples,
                                                                                         label_list=label_list,
                                                                                         tokenizer=tokenizer,
                                                                                         train_idx_offset=id_eval)
                                id_eval += batch_size
                                sess.run(metric_op,
                                         feed_dict={
                                             input_ids_ph: eval_input_ids,
                                             input_mask_ph: eval_input_mask,
                                             segment_ids_ph: eval_segment_ids,
                                             label_ids_ph: eval_label_ids,
                                             is_training_ph: False,
                                             keep_prob_ph: 1,
                                             head_mask_ph: head_mask,
                                             layer_mask_ph: layer_mask
                                         })

                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={
                                    ph: value
                                    for ph, value in zip(
                                        metric_phs, eval_metric_val)
                                })

                        writer_eval.add_summary(eval_summary_str, step)

                        if step == 1:
                            for key, val in zip(metric_name, eval_metric_val):
                                start_metric["eval_{}".format(key)] = val
                        if step == epochs * num_batch:
                            for key, val in zip(metric_name, eval_metric_val):
                                end_metric["eval_{}".format(key)] = val

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        train_metric_list = []
                        for i in range(len(train_metric_val)):
                            if metric_name[i] == 'loss':
                                train_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                            else:
                                train_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                        train_str = 'Train ' + '|'.join(train_metric_list)

                        eval_metric_list = []
                        for i in range(len(eval_metric_val)):
                            if metric_name[i] == 'loss':
                                eval_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                            else:
                                eval_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                        eval_str = 'Eval ' + '|'.join(eval_metric_list)

                        print(
                            "Layer {}, leave only one head {} | Epoch: %4d/%4d | Batch: %4d/%4d | {} | {}"
                            .format(cur_layer, most_important_head, train_str,
                                    eval_str) % (n, epochs, b, num_batch))

                    if step % num_batch == 0:
                        saver_all.save(sess,
                                       log_dir + 'layer{}_head_{}'.format(
                                           cur_layer, most_important_head),
                                       global_step=step)

                    step += 1

            writer.close()
            writer_eval.close()

        print("Start metric", start_metric)
        print("End metric", end_metric)

        with tf.io.gfile.GFile(FLAGS.output_dir + 'results.txt',
                               'a') as writer:
            eval_start, eval_end = [], []
            for metric in metric_name:
                if metric != 'loss':
                    eval_start.append("{}: %.4f".format(metric) %
                                      start_metric["eval_{}".format(metric)])
                    eval_end.append("{}: %.4f".format(metric) %
                                    end_metric["eval_{}".format(metric)])

            writer.write(
                "Layer {}, leave only one head {}: Start: {} | End: {}\n".
                format(cur_layer, most_important_head, ','.join(eval_start),
                       ','.join(eval_end)))