def main(_):
    """Evaluate the dev set with one attention head masked out at a time.

    The classifier graph is built once and ``FLAGS.init_checkpoint`` is
    restored into it.  Then, for every (layer, head) coordinate of a 12 x 12
    grid, that single coordinate is fed through the ``layer_mask`` /
    ``head_mask`` placeholders and the dev set is evaluated.  The value of the
    metric named ``accuracy`` is stored per coordinate and the resulting
    ``(12, 12)`` array is dumped with joblib to
    ``<FLAGS.output_dir>/single_head_mask.pickle``.

    Args:
        _: Unused; present only because the entry point is called with one
            positional argument.

    Raises:
        ValueError: If ``FLAGS.do_eval`` is false, if ``FLAGS.max_seq_length``
            exceeds the model's ``max_position_embeddings``, or if
            ``FLAGS.task_name`` is not a known task.
    """
    import os  # function-scope import: only needed to build the output path

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_eval:
        raise ValueError("At least 'do_eval' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest: from here on it is treated as
    # 'mnlim' (the name is what metric_fn receives below).
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.eval_batch_size
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, ],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Running single-head masking out and direct evaluation!")

    # we want to mask out the individual head and then evaluate.
    # So there are 12 layers * 12 heads results.
    # NOTE(review): the grid size is hard-coded; presumably it should follow
    # the model config for non-base models -- confirm.
    n_layers = 12
    n_heads = 12
    save_file = 'single_head_mask.pickle'
    output = np.zeros((n_layers, n_heads))

    # two placeholders for the head coordinates, layer, head
    head_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, ],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, ],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,
        input_mask=input_mask_ph,
        token_type_ids=segment_ids_ph,
        use_one_hot_embeddings=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    output_layer = model.get_pooled_output()
    output_weights = tf.compat.v1.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.compat.v1.get_variable(
        "output_bias", [num_labels],
        initializer=tf.compat.v1.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            # NOTE(review): label_ids_ph is int32 while logits are float; this
            # subtraction looks like it needs a cast -- confirm the STS-B path.
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)
        # `loss` and `predictions` are never fetched in this script; they are
        # kept so the graph is unchanged.
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric maps a metric name to a (value_tensor, update_op) pair
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.compat.v1.train.Saver(tvars)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.compat.v1.variables_initializer(
        var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        # Featurize the dev set exactly once instead of once per masked head.
        # Up to 1000 examples go in as a single batch, larger sets are cut
        # into `batch_size` chunks.  (Assumes generate_ph_input returns the
        # same features for the same offset on every call.)
        feature_args = dict(seq_length=seq_length,
                            examples=eval_examples,
                            label_list=label_list,
                            tokenizer=tokenizer)
        if num_actual_eval_examples <= 1000:
            eval_batches = [
                generate_ph_input(batch_size=num_actual_eval_examples,
                                  **feature_args)
            ]
        else:
            eval_batches = [
                generate_ph_input(batch_size=batch_size,
                                  train_idx_offset=offset,
                                  **feature_args)
                for offset in range(0, num_actual_eval_examples, batch_size)
            ]

        # loop over layers, then loop over heads
        for cur_l in range(n_layers):
            for cur_h in range(n_heads):
                # Reset the running metric state, stream every batch through
                # the update ops, then read the aggregated values.
                sess.run(metric_vars_initializer)
                for (batch_input_ids, batch_input_mask, batch_segment_ids,
                     batch_label_ids, _is_real_example) in eval_batches:
                    sess.run(metric_op,
                             feed_dict={
                                 input_ids_ph: batch_input_ids,
                                 input_mask_ph: batch_input_mask,
                                 segment_ids_ph: batch_segment_ids,
                                 label_ids_ph: batch_label_ids,
                                 head_mask_ph: [cur_h],
                                 layer_mask_ph: [cur_l]
                             })
                eval_metric_val = sess.run(metric_val)

                for name, val in zip(metric_name, eval_metric_val):
                    # only accuracy is persisted; every metric is printed
                    if name == 'accuracy':
                        output[cur_l][cur_h] = val
                    print(
                        "Mask out the head in (Layer {}, Head {}) | {}: {}"
                        .format(cur_l, cur_h, name, val))

    # os.path.join keeps the file inside output_dir whether or not the flag
    # was given with a trailing separator.
    joblib.dump(output, os.path.join(FLAGS.output_dir, save_file))
def main(_):
    """Evaluate the dev set while cumulatively masking out attention heads.

    Heads are masked in the order given by a fixed importance ranking
    (reversed when ``FLAGS.from_most`` is set).  At step ``k`` the first ``k``
    ranked (layer, head) coordinates are fed through the ``layer_mask`` /
    ``head_mask`` placeholders and the dev set is evaluated.  The ``accuracy``
    obtained at step ``k`` is stored at the coordinate of the ``k``-th head
    and the ``(12, 12)`` array is dumped with joblib to
    ``<FLAGS.output_dir>/cumulative_heads_mask.pickle``.

    Args:
        _: Unused; present only because the entry point is called with one
            positional argument.

    Raises:
        ValueError: If ``FLAGS.do_eval`` is false, if ``FLAGS.max_seq_length``
            exceeds the model's ``max_position_embeddings``, if
            ``FLAGS.task_name`` is unknown, or if
            ``FLAGS.importance_setting`` is not one of the supported rankings.
    """
    import os  # function-scope import: only needed to build the output path

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_eval:
        raise ValueError("At least 'do_eval' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    # The training split is loaded only so that its size can be reported.
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.train_batch_size
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, ],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Running cumulatively mask out heads and direct evaluation!")

    # we want to mask out the individual head and then evaluate.
    # So there are 12 layers * 12 heads results.
    n_layers = 12
    n_heads = 12
    save_file = 'cumulative_heads_mask.pickle'
    output = np.zeros((n_layers, n_heads))

    #######################
    # 12 * 12, layer * head, importance increases from 1 to 144.
    # [[127, 63, 67, 72, 91, 93, 124, 100, 96, 134, 15, 133],
    #  [143, 60, 107, 128, 57, 106, 118, 83, 144, 135, 99, 116],
    #  [30, 111, 115, 132, 112, 94, 122, 66, 123, 40, 75, 89],
    #  [108, 140, 141, 119, 121, 62, 137, 131, 125, 70, 85, 105],
    #  [126, 120, 113, 136, 92, 79, 110, 74, 103, 84, 86, 53],
    #  [69, 87, 117, 80, 142, 114, 129, 104, 97, 18, 52, 77],
    #  [81, 88, 48, 90, 56, 50, 58, 101, 130, 64, 35, 46],
    #  [20, 41, 38, 32, 71, 59, 82, 43, 78, 55, 47, 37],
    #  [24, 95, 27, 44, 65, 12, 28, 102, 98, 23, 14, 19],
    #  [17, 16, 2, 1, 9, 13, 68, 4, 139, 7, 21, 109],
    #  [76, 26, 8, 138, 10, 29, 31, 54, 6, 36, 3, 49],
    #  [61, 51, 42, 45, 34, 5, 73, 33, 25, 22, 39, 11]]
    #######################
    # sorted head coordinate array by importance (L2 weight magnitude),
    # (layer, head), least 1 --> most 144
    ranked_by_l2_norm = [
        [9, 3], [9, 2], [10, 10], [9, 7], [11, 5], [10, 8], [9, 9], [10, 2],
        [9, 4], [10, 4], [11, 11], [8, 5], [9, 5], [8, 10], [0, 10], [9, 1],
        [9, 0], [5, 9], [8, 11], [7, 0], [9, 10], [11, 9], [8, 9], [8, 0],
        [11, 8], [10, 1], [8, 2], [8, 6], [10, 5], [2, 0], [10, 6], [7, 3],
        [11, 7], [11, 4], [6, 10], [10, 9], [7, 11], [7, 2], [11, 10], [2, 9],
        [7, 1], [11, 2], [7, 7], [8, 3], [11, 3], [6, 11], [7, 10], [6, 2],
        [10, 11], [6, 5], [11, 1], [5, 10], [4, 11], [10, 7], [7, 9], [6, 4],
        [1, 4], [6, 6], [7, 5], [1, 1], [11, 0], [3, 5], [0, 1], [6, 9],
        [8, 4], [2, 7], [0, 2], [9, 6], [5, 0], [3, 9], [7, 4], [0, 3],
        [11, 6], [4, 7], [2, 10], [10, 0], [5, 11], [7, 8], [4, 5], [5, 3],
        [6, 0], [7, 6], [1, 7], [4, 9], [3, 10], [4, 10], [5, 1], [6, 1],
        [2, 11], [6, 3], [0, 4], [4, 4], [0, 5], [2, 5], [8, 1], [0, 8],
        [5, 8], [8, 8], [1, 10], [0, 7], [6, 7], [8, 7], [4, 8], [5, 7],
        [3, 11], [1, 5], [1, 2], [3, 0], [9, 11], [4, 6], [2, 1], [2, 4],
        [4, 2], [5, 5], [2, 2], [1, 11], [5, 2], [1, 6], [3, 3], [4, 1],
        [3, 4], [2, 6], [2, 8], [0, 6], [3, 8], [4, 0], [0, 0], [1, 3],
        [5, 6], [6, 8], [3, 7], [2, 3], [0, 11], [0, 9], [1, 9], [4, 3],
        [3, 6], [10, 3], [9, 8], [3, 1], [3, 2], [5, 4], [1, 0], [1, 8],
    ]
    # NOTE(review): the 'per_head_score' branch originally carried a second
    # literal that is identical, pair for pair, to the L2-norm ranking above
    # (including its "L2 weight magnitude" comment).  Both settings therefore
    # share one table here; confirm whether a distinct per-head-score ranking
    # was intended.
    rankings = {
        'l2_norm': ranked_by_l2_norm,
        'per_head_score': ranked_by_l2_norm,
    }
    if FLAGS.importance_setting not in rankings:
        # Fail before the graph is built instead of hitting an IndexError on
        # an empty ranking after the checkpoint has been restored.
        raise ValueError("Unknown importance_setting: %s" %
                         (FLAGS.importance_setting,))

    # Work on a copy so the shared table is never mutated.
    importance_coordinates = list(rankings[FLAGS.importance_setting])
    # mask out heads from the most important one
    if FLAGS.from_most:
        importance_coordinates.reverse()

    # two placeholders for the head coordinates, layer, head
    head_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, ],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, ],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,
        input_mask=input_mask_ph,
        token_type_ids=segment_ids_ph,
        use_one_hot_embeddings=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    output_layer = model.get_pooled_output()
    output_weights = tf.compat.v1.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.compat.v1.get_variable(
        "output_bias", [num_labels],
        initializer=tf.compat.v1.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            # NOTE(review): label_ids_ph is int32 while logits are float; this
            # subtraction looks like it needs a cast -- confirm the STS-B path.
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)
        # `loss` and `predictions` are never fetched in this script; they are
        # kept so the graph is unchanged.
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric maps a metric name to a (value_tensor, update_op) pair
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.compat.v1.train.Saver(tvars)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.compat.v1.variables_initializer(
        var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        # Featurize the dev set exactly once instead of once per masking step.
        # Up to 1000 examples go in as a single batch, larger sets are cut
        # into `batch_size` chunks.  (Assumes generate_ph_input returns the
        # same features for the same offset on every call.)
        feature_args = dict(seq_length=seq_length,
                            examples=eval_examples,
                            label_list=label_list,
                            tokenizer=tokenizer)
        if num_actual_eval_examples <= 1000:
            eval_batches = [
                generate_ph_input(batch_size=num_actual_eval_examples,
                                  **feature_args)
            ]
        else:
            eval_batches = [
                generate_ph_input(batch_size=batch_size,
                                  train_idx_offset=offset,
                                  **feature_args)
                for offset in range(0, num_actual_eval_examples, batch_size)
            ]

        # Walk the ranking; the masks grow by one coordinate per step so step
        # idx masks the heads ranked 0..idx.
        layer_mask = []
        head_mask = []
        for idx, (cur_l, cur_h) in enumerate(importance_coordinates):
            layer_mask.append(cur_l)
            head_mask.append(cur_h)

            # Reset the running metric state, stream every batch through the
            # update ops, then read the aggregated values.
            sess.run(metric_vars_initializer)
            for (batch_input_ids, batch_input_mask, batch_segment_ids,
                 batch_label_ids, _is_real_example) in eval_batches:
                sess.run(metric_op,
                         feed_dict={
                             input_ids_ph: batch_input_ids,
                             input_mask_ph: batch_input_mask,
                             segment_ids_ph: batch_segment_ids,
                             label_ids_ph: batch_label_ids,
                             head_mask_ph: head_mask,
                             layer_mask_ph: layer_mask
                         })
            eval_metric_val = sess.run(metric_val)

            for name, val in zip(metric_name, eval_metric_val):
                # only accuracy is persisted; every metric is printed
                if name == 'accuracy':
                    output[cur_l][cur_h] = val
                print("Mask out the {}th head in (Layer {}, Head {}) | {}: {}"
                      .format(idx + 1, cur_l, cur_h, name, val))

    # os.path.join keeps the file inside output_dir whether or not the flag
    # was given with a trailing separator.
    joblib.dump(output, os.path.join(FLAGS.output_dir, save_file))
def main(_):
    """Fine-tune, evaluate and/or predict with re-ordered BERT encoder layers.

    ``FLAGS.layers`` is a comma-separated list of layer indices.  With exactly
    two indices those two layers are swapped; with any other count the listed
    indices are inserted in reverse order between the layers before the first
    and after the last listed index.  The resulting order is passed to
    ``modeling.BertModel`` via ``layer_order``.

    Depending on ``FLAGS.do_train`` / ``FLAGS.do_eval`` / ``FLAGS.do_pred``
    the function trains (writing summaries and one checkpoint per epoch),
    evaluates on the dev set (writing ``dev_predictions.tsv``) and predicts on
    the test set (writing ``test_predictions.tsv`` and a submission file).
    A one-line summary is appended to ``<FLAGS.output_dir>/results.txt``.

    Args:
        _: Unused; present only because the entry point is called with one
            positional argument.

    Raises:
        ValueError: If none of the three ``do_*`` flags is set, if
            ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings``, if ``FLAGS.task_name`` is unknown,
            if ``FLAGS.layers`` is ``None``, or if the two layers to swap are
            the same.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_pred):
        raise ValueError(
            "At least one of 'do_train', 'do_eval' or 'do_pred' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest: from here on it is treated as
    # 'mnlim' (metric_fn and the submission file name see that name).
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)
    if FLAGS.do_pred:
        test_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_test_examples = len(test_examples)
        print("num_actual_test_examples", num_actual_test_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, ],
                                            name='label_ids')

    tf.compat.v1.logging.info("Running swapping layers!")

    num_train_steps = num_actual_train_examples // batch_size * epochs
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Swapped layers
    layer_folder_name = FLAGS.layers
    if FLAGS.layers is None:
        raise ValueError("In swapping experiments, layers must not be None. ")
    layers = list(map(int, FLAGS.layers.split(',')))
    layer_order = list(range(12)) if embed_dim == 768 else list(range(6))
    if len(layers) == 2:
        # Unpack only once the length is known, so that a single-index
        # `--layers` value no longer dies with an IndexError.
        layer1, layer2 = layers
        if layer1 == layer2:
            raise ValueError("Two layers should be different! ")
        layer_order[layer1], layer_order[layer2] = (layer_order[layer2],
                                                    layer_order[layer1])
    else:
        # Any other count: splice the listed indices in reversed order between
        # the untouched prefix and suffix.
        layer_order = (layer_order[:layers[0]] + layers[::-1] +
                       layer_order[layers[-1] + 1:])
    print("Current layer order: ", layer_order)

    # this placeholder is to control the flag for the dropout
    keep_prob_ph = tf.compat.v1.placeholder(tf.float32, name="keep_prob")
    is_training_ph = tf.compat.v1.placeholder(tf.bool, name='is_training')

    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training_ph,
        input_ids=input_ids_ph,
        input_mask=input_mask_ph,
        token_type_ids=segment_ids_ph,
        use_one_hot_embeddings=False,
        layer_order=layer_order,
        use_estimator=False)

    output_layer = model.get_pooled_output()
    output_layer = tf.compat.v1.nn.dropout(output_layer,
                                           keep_prob=keep_prob_ph)
    output_weights = tf.compat.v1.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.compat.v1.get_variable(
        "output_bias", [num_labels],
        initializer=tf.compat.v1.zeros_initializer())
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            # NOTE(review): label_ids_ph is int32 while logits are float; this
            # subtraction looks like it needs a cast -- confirm the STS-B path.
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # metric and summary
    # metric maps a metric name to a (value_tensor, update_op) pair
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    # Summaries are written from already-computed metric values, which are
    # fed back in through one float placeholder per metric.
    metric_phs = [
        tf.compat.v1.placeholder(tf.float32, name="{}_ph".format(key))
        for key in metric_name
    ]
    summaries = [
        tf.compat.v1.summary.scalar(key, ph)
        for key, ph in zip(metric_name, metric_phs)
    ]
    train_summary_total = tf.compat.v1.summary.merge(summaries)
    eval_summary_total = tf.compat.v1.summary.merge(summaries)

    # os.path.join keeps everything under output_dir whether or not the flag
    # was given with a trailing separator.
    log_dir = os.path.join(FLAGS.output_dir,
                           'layer_{}'.format(layer_folder_name))

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [
        v for v in tvars
        if 'output_weights' not in v.name and 'output_bias' not in v.name
    ]
    var_output = [
        v for v in tvars
        if 'output_weights' in v.name or "output_bias" in v.name
    ]

    if not FLAGS.load_from_finetuned:
        # Init from Model0: the classifier head is not restored
        saver_init = tf.compat.v1.train.Saver(var_init)
    else:
        # Init from Model1: the classifier head is restored as well
        saver_init = tf.compat.v1.train.Saver(var_init + var_output)

    var_train = var_init + var_output
    print("Training parameters")
    for v in var_train:
        print(v)

    train_op = optimization.create_optimizer(loss=loss,
                                             init_lr=FLAGS.learning_rate,
                                             num_train_steps=num_train_steps,
                                             num_warmup_steps=num_warmup_steps,
                                             use_tpu=False,
                                             tvars=var_train)

    saver_all = tf.compat.v1.train.Saver(var_list=var_train, max_to_keep=1)

    # Isolate the variables stored behind the scenes by the metric operation
    var_metric = []
    for key in metric_name:
        var_metric.extend(
            tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables
    metric_vars_initializer = tf.compat.v1.variables_initializer(
        var_list=var_metric)

    def load_batches(examples):
        """Featurize `examples`: one batch if <= 1000, else batch_size chunks."""
        total = len(examples)
        common = dict(seq_length=seq_length,
                      examples=examples,
                      label_list=label_list,
                      tokenizer=tokenizer)
        if total <= 1000:
            return [generate_ph_input(batch_size=total, **common)]
        return [
            generate_ph_input(batch_size=batch_size,
                              train_idx_offset=offset,
                              **common)
            for offset in range(0, total, batch_size)
        ]

    def make_feed(batch, training):
        """Build the feed dict for one featurized batch."""
        batch_input_ids, batch_input_mask, batch_segment_ids, batch_labels = \
            batch[:4]
        return {
            input_ids_ph: batch_input_ids,
            input_mask_ph: batch_input_mask,
            segment_ids_ph: batch_segment_ids,
            label_ids_ph: batch_labels,
            is_training_ph: training,
            keep_prob_ph: 0.9 if training else 1
        }

    def evaluate(sess, batches, num_examples=None):
        """Stream `batches` through the metrics; optionally gather predictions.

        Returns (metric_values, preds, labels).  preds/labels are int arrays
        truncated to `num_examples`, or None when `num_examples` is None.
        """
        sess.run(metric_vars_initializer)
        pred_chunks, label_chunks = [], []
        for batch in batches:
            feed = make_feed(batch, training=False)
            if num_examples is None:
                sess.run(metric_op, feed_dict=feed)
            else:
                # One forward pass yields both the metric update and the
                # predictions (dropout is off, so the result is the same as
                # running them separately).
                fetched = sess.run(metric_op + [predictions], feed_dict=feed)
                pred_chunks.append(np.atleast_1d(fetched[-1]))
                label_chunks.append(np.atleast_1d(batch[3]))
        metric_values = sess.run(metric_val)
        if num_examples is None:
            return metric_values, None, None
        preds = np.concatenate(pred_chunks)[:num_examples].astype(int)
        labels = np.concatenate(label_chunks)[:num_examples].astype(int)
        return metric_values, preds, labels

    def format_metrics(values):
        """Render metric values as 'name: value' joined by '|'."""
        parts = []
        for name, value in zip(metric_name, values):
            template = "{}: %2.4f" if name == 'loss' else "{}: %.4f"
            parts.append(template.format(name) % value)
        return '|'.join(parts)

    def write_predictions(path, preds, labels):
        """Write the per-example prediction table to `path`."""
        with tf.io.gfile.GFile(path, "w") as out:
            if task_name != 'stsb':
                out.write(
                    "ID \t Label ID \t Label \t Ground Truth ID \t Ground Truth \n"
                )
                for i, pred in enumerate(preds):
                    out.write("{} \t {} \t {} \t {} \t {} \n".format(
                        i, pred, label_list[pred], labels[i],
                        label_list[labels[i]]))
            else:
                out.write("ID \t Label \n")
                for i, pred in enumerate(preds):
                    # use the row index; the former counter was never advanced
                    out.write("{} \t {} \n".format(i, pred))

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(
            os.path.join(log_dir, 'log', 'train'), sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(
            os.path.join(log_dir, 'log', 'eval'))

        # Featurize the dev set once; it is reused by every in-training
        # evaluation and by the final evaluation.  (Assumes generate_ph_input
        # returns the same features for the same offset on every call.)
        eval_batches = None
        if FLAGS.do_train or FLAGS.do_eval:
            eval_batches = load_batches(eval_examples)

        start_metric = {"eval_{}".format(key): 0 for key in metric_name}
        if FLAGS.do_train:
            tf.compat.v1.logging.info("***** Run training *****")
            step = 1
            for n in range(epochs):
                np.random.shuffle(train_examples)
                # number of batches, rounding the last partial batch up
                num_batch = -(-num_actual_train_examples // batch_size)
                for b in range(num_batch):
                    train_batch = generate_ph_input(
                        batch_size=batch_size,
                        seq_length=seq_length,
                        examples=train_examples,
                        label_list=label_list,
                        tokenizer=tokenizer,
                        train_idx_offset=b * batch_size)

                    # Metrics are reset per step, so the training summary
                    # reflects the current batch only.
                    sess.run(metric_vars_initializer)
                    sess.run([train_op] + metric_op,
                             feed_dict=make_feed(train_batch, training=True))
                    train_metric_val = sess.run(metric_val)
                    train_summary_str = sess.run(
                        train_summary_total,
                        feed_dict=dict(zip(metric_phs, train_metric_val)))
                    writer.add_summary(train_summary_str, step)

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        # evaluate on dev set
                        eval_metric_val, _preds, _labels = evaluate(
                            sess, eval_batches)
                        eval_summary_str = sess.run(
                            eval_summary_total,
                            feed_dict=dict(zip(metric_phs, eval_metric_val)))
                        writer_eval.add_summary(eval_summary_str, step)

                        if step == 1:
                            for key, val in zip(metric_name, eval_metric_val):
                                start_metric["eval_{}".format(key)] = val

                        train_str = 'Train ' + format_metrics(train_metric_val)
                        eval_str = 'Eval ' + format_metrics(eval_metric_val)
                        print(
                            "Swap layer order {} | Epoch: %4d/%4d | Batch: %4d/%4d | {} | {}"
                            .format(layer_folder_name, train_str, eval_str) %
                            (n, epochs, b, num_batch))

                    if step % num_batch == 0:
                        saver_all.save(sess,
                                       os.path.join(
                                           log_dir,
                                           'swap_{}'.format(layer_folder_name)),
                                       global_step=step)
                    step += 1

        # Close the summary writers whether or not training ran; they are not
        # used past this point.
        writer.close()
        writer_eval.close()

        end_metric = {"eval_{}".format(key): 0 for key in metric_name}
        if FLAGS.do_eval:
            tf.compat.v1.logging.info("***** Run evaluation *****")
            eval_metric_val, preds, eval_label_ids_lst = evaluate(
                sess, eval_batches, num_examples=num_actual_eval_examples)
            for key, val in zip(metric_name, eval_metric_val):
                end_metric["eval_{}".format(key)] = val

            write_predictions(os.path.join(log_dir, 'dev_predictions.tsv'),
                              preds, eval_label_ids_lst)
            tf.compat.v1.logging.info("***** Finish writing *****")

        print("Start metric", start_metric)
        print("End metric", end_metric)

        test_metric = {"test_{}".format(key): 0 for key in metric_name}
        if FLAGS.do_pred:
            # prediction
            tf.compat.v1.logging.info("***** Predict results *****")
            test_metric_val, preds, test_label_ids_lst = evaluate(
                sess,
                load_batches(test_examples),
                num_examples=num_actual_test_examples)
            for key, val in zip(metric_name, test_metric_val):
                test_metric["test_{}".format(key)] = val

            write_predictions(os.path.join(log_dir, 'test_predictions.tsv'),
                              preds, test_label_ids_lst)

            submit_predict_file = os.path.join(
                log_dir, "{}.tsv".format(standard_file_name[task_name]))
            with tf.io.gfile.GFile(submit_predict_file, 'w') as writer_submit:
                writer_submit.write("ID \t Label \n")
                for i, pred in enumerate(preds):
                    if task_name != 'stsb':
                        writer_submit.write("{} \t {} \n".format(
                            i, label_list[pred]))
                    else:
                        writer_submit.write("{} \t {} \n".format(i, pred))
            tf.compat.v1.logging.info("***** Finish writing *****")

        results_path = os.path.join(FLAGS.output_dir, 'results.txt')
        with tf.io.gfile.GFile(results_path, 'a') as results_file:
            eval_start, eval_end, test_end = [], [], []
            for name in metric_name:
                if name == 'loss':
                    continue  # the loss is left out of the summary line
                eval_start.append("{}: %.4f".format(name) %
                                  start_metric["eval_{}".format(name)])
                eval_end.append("{}: %.4f".format(name) %
                                end_metric["eval_{}".format(name)])
                test_end.append("{}: %.4f".format(name) %
                                test_metric["test_{}".format(name)])

            results_file.write(
                "Swap layer order {}: Dev start: {} | Dev end: {} | Test end: {}\n"
                .format(layer_folder_name, ','.join(eval_start),
                        ','.join(eval_end), ','.join(test_end)))
def main(_):
    """Dump intermediate token embeddings of a BERT model to a pickle file.

    Builds the BERT graph on placeholders, restores ``FLAGS.init_checkpoint``
    and runs a single forward pass over either the whole dev set of
    ``FLAGS.task_name`` or, with ``FLAGS.feed_ones``, one synthetic sample
    (all-zero ids, all-one mask).  The embedding output, four per-layer
    intermediate tensors and the input tokens are written with ``joblib`` to
    ``<FLAGS.output_dir>/track_embeddings.pickle``.

    Args:
        _: Ignored positional argument.

    Raises:
        ValueError: If ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings`` or the task name is not registered.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor,
        "semeval": SemEvalProcessor,
    }

    def parse_layer_list(flag_value):
        """Turn a comma-separated flag such as "3,1" into a sorted int list."""
        if flag_value is None:
            return []
        return sorted(map(int, flag_value.split(',')))

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    n_layers = FLAGS.n_layers
    # hidden size, 768 for BERT-base, 512 for BERT-small
    embed_dim = FLAGS.hidden_size
    seq_length = FLAGS.max_seq_length

    # Placeholders for the model input.
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')

    tf.compat.v1.logging.info(
        "Running tracking hidden token embeddings through the model!")

    # The only size reduction currently active is `feed_ones` (one sample).
    # An optional random subsample of 1000 dev examples (np.random.seed(0) +
    # np.random.choice without replacement) used to live here but is disabled.
    if FLAGS.feed_ones:
        num_actual_eval_examples = 1

    print("here we only take {} samples.".format(num_actual_eval_examples))
    tf.compat.v1.logging.info(
        "For faster calculation, we reduce the example size! ")

    save_file = 'track_embeddings.pickle'

    layers_cancel_skip_connection = parse_layer_list(
        FLAGS.layers_cancel_skip_connection)
    print("Layers need to cancel skip-connection: ",
          layers_cancel_skip_connection)

    layers_use_relu = parse_layer_list(FLAGS.layers_use_relu)
    print("Layers need to use ReLU: ", layers_use_relu)

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,
        input_mask=input_mask_ph,
        token_type_ids=segment_ids_ph,
        use_one_hot_embeddings=False,
        cancel_skip_connection=layers_cancel_skip_connection,
        layer_use_relu=layers_use_relu,
        feed_same=FLAGS.feed_same)

    # Tensors that are tracked.  Other getters that were tracked at some
    # point and can be added here again: get_in_embeds()[0],
    # get_all_head_output(), get_all_attention_before_layernorm(),
    # get_all_layer_tokens_after_FFGeLU() (last dim embed_dim * 4) and
    # get_all_ffn_before_layernorm().
    final_embeddings = model.get_embedding_output()
    layer_fetches = [
        ("layer_self_attention_ff", model.get_all_attention_before_dropout()),
        ("layer_attention", model.get_all_layer_tokens_beforeMLP()),
        ("layer_mlp", model.get_all_layer_tokens_afterMLP()),
        ("layer_output", model.get_all_encoder_layers()),
    ]

    # Load weights.  The compat.v1 entry points are used throughout so the
    # block does not depend on TF1-only top-level aliases.
    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.compat.v1.train.Saver(tvars)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        if FLAGS.feed_ones:
            # One synthetic sample: token id 0 everywhere, nothing masked.
            input_ids = np.zeros([1, seq_length])
            input_mask = np.ones([1, seq_length])
            segment_ids = np.zeros([1, seq_length])
            input_tokens = [[''] * seq_length]
        else:
            # Labels and the is-real-example flags are not needed here.
            input_ids, input_mask, segment_ids, _, _, input_tokens = \
                generate_ph_input(batch_size=num_actual_eval_examples,
                                  seq_length=seq_length,
                                  examples=eval_examples,
                                  label_list=label_list,
                                  tokenizer=tokenizer,
                                  return_tokens=True)

        final_embeddings_val, layer_vals = sess.run(
            [final_embeddings, [tensor for _, tensor in layer_fetches]],
            feed_dict={
                input_ids_ph: input_ids,
                input_mask_ph: input_mask,
                segment_ids_ph: segment_ids
            })

        sample_shape = [num_actual_eval_examples, seq_length, embed_dim]

        # Key order matches the historical pickle layout.
        output = {
            "final_embeddings": np.reshape(final_embeddings_val, sample_shape)
        }
        for name, _ in layer_fetches:
            output[name] = np.zeros([n_layers] + sample_shape)

        for layer in tqdm(range(n_layers)):
            for (name, _), values in zip(layer_fetches, layer_vals):
                output[name][layer] = np.reshape(values[layer], sample_shape)

        output["sentence"] = input_tokens

        joblib.dump(output, os.path.join(FLAGS.output_dir, save_file))
def main(_):
    """Train an approximator that mimics one part of one BERT layer.

    The BERT weights are restored from ``FLAGS.init_checkpoint`` and kept
    fixed; only variables whose name contains ``"non-linearity"`` are
    optimised.  ``FLAGS.approximate_part`` selects which input/output tensor
    pair of layer ``FLAGS.layers`` is approximated, ``FLAGS.loss`` selects the
    objective (``cosine``, ``mse`` or ``l2``).  During training, summaries are
    written below ``<output_dir>layer_<layers>/log/`` and the approximator is
    checkpointed once per epoch.  Afterwards the per-token cosine similarity
    on the dev set is averaged separately over real tokens, padding tokens
    and all tokens and appended to ``<output_dir>results.txt``.

    Args:
        _: Ignored positional argument.

    Raises:
        ValueError: On an over-long sequence length, an unknown task, a
            missing or multi-valued ``FLAGS.layers``, an unknown
            ``FLAGS.approximate_part`` or an unknown ``FLAGS.loss``.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor,
    }

    def parse_layer_list(flag_value):
        """Turn a comma-separated flag such as "3,1" into a sorted int list."""
        if flag_value is None:
            return []
        return sorted(map(int, flag_value.split(',')))

    def ceil_div(total, chunk):
        """Number of chunks of size `chunk` needed to cover `total` items."""
        return (total + chunk - 1) // chunk

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    # NOTE: every output path below is built by plain string concatenation,
    # so FLAGS.output_dir is expected to end with a path separator.
    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    # hidden size, 768 for BERT-base, 512 for BERT-small
    embed_dim = FLAGS.hidden_size
    seq_length = FLAGS.max_seq_length

    # Placeholders for the model input.
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')

    # train an independent model to approximate the selected projection
    tf.compat.v1.logging.info("Training Approximator!")

    # Exactly one layer can be approximated per run.
    if FLAGS.layers is None:
        raise ValueError(
            "In training non-linearity experiments, layers must not be None. ")
    layer_folder_name = FLAGS.layers
    layers = parse_layer_list(FLAGS.layers)
    if len(layers) != 1:
        raise ValueError("Here it allows only one single layer. ")
    approximated_layer = layers[0]
    print("Current approximated layer: ", approximated_layer)

    approximator_setting = FLAGS.approximator_setting

    layers_cancel_skip_connection = parse_layer_list(
        FLAGS.layers_cancel_skip_connection)
    print("Layers need to cancel skip-connection: ",
          layers_cancel_skip_connection)

    layers_use_relu = parse_layer_list(FLAGS.layers_use_relu)
    print("Layers need to use ReLU: ", layers_use_relu)

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,
        input_mask=input_mask_ph,
        token_type_ids=segment_ids_ph,
        use_one_hot_embeddings=False,
        approximator_setting=approximator_setting,
        cancel_skip_connection=layers_cancel_skip_connection,
        layer_use_relu=layers_use_relu)

    def layer_input():
        """Output of the previous encoder layer (embeddings for layer 0)."""
        if approximated_layer > 0:
            return model.get_all_encoder_layers()[approximated_layer - 1]
        return model.get_embedding_output()

    # Pick input x and target y according to the approximated part.
    part = FLAGS.approximate_part
    if part == 'mlp':
        # only FFGeLU and FF in FFN, without dropout and layernorm
        x = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
        y = model.get_all_layer_tokens_afterMLP()[approximated_layer]
    elif part == 'ffgelu':
        # only FFGeLU in FFN
        x = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
        y = model.get_all_layer_tokens_after_FFGeLU()[approximated_layer]
    elif part == 'self_attention':
        # only self-attention, without linear layer, dropout and layernorm
        x = layer_input()
        y = model.get_all_head_output()[approximated_layer]
    elif part == 'self_attention_ff':
        # self-attention + linear part, without dropout and layernorm
        x = layer_input()
        y = model.get_all_attention_before_dropout()[approximated_layer]
    elif part == 'ff_after_self_attention':
        # only the linear layer after self-attention
        x = model.get_all_head_output()[approximated_layer]
        y = model.get_all_attention_before_dropout()[approximated_layer]
    elif part == 'attention_before_ln':
        # self-attention + linear part, before layernorm
        x = layer_input()
        y = model.get_all_attention_before_layernorm()[approximated_layer]
    elif part == 'attention':
        # whole attention block including dropout and layernorm
        x = layer_input()
        y = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
    elif part == 'ffn':
        # whole FFN block including dropout and layernorm
        x = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
        y = model.get_all_encoder_layers()[approximated_layer]
    elif part == 'ffn_before_ln':
        # only FFGeLU and FF in FFN, before layernorm
        x = model.get_all_layer_tokens_beforeMLP()[approximated_layer]
        y = model.get_all_ffn_before_layernorm()[approximated_layer]
    elif part == 'encoder':
        # whole encoder including dropout and layernorm
        x = layer_input()
        y = model.get_all_encoder_layers()[approximated_layer]
    else:
        raise ValueError("Need to specify correct value. ")

    # Settings whose target has the hidden size; anything else is treated as
    # the 4x wide FFGeLU output (HS*4_FFGeLU).
    hidden_size_settings = (
        'HS_MLP', 'HS*4+HS_MLP', 'HS_Self_Attention', 'HS_Self_Attention_FF',
        'HS_FFN', 'HS_Attention', 'HS_Encoder', 'HS_Attention_Before_LN',
        'HS_FFN_Before_LN', 'HS_FF_After_Self_Attention')
    if FLAGS.approximator_setting in hidden_size_settings:
        approximator_dim = embed_dim
    else:
        approximator_dim = embed_dim * 4

    x = tf.reshape(x, [-1, seq_length, embed_dim])
    y = tf.reshape(y, [-1, seq_length, approximator_dim])

    if FLAGS.use_nonlinear_approximator:
        y_pred = approximator.nonlinear_approximator(
            input=x,
            approximated_layer=approximated_layer,
            hidden_size=approximator_dim,
            num_layer=1,
            use_dropout=FLAGS.use_dropout,
            dropout_p=0.2)
    else:
        y_pred = approximator.linear_approximator(
            input=x,
            approximated_layer=approximated_layer,
            hidden_size=approximator_dim)

    # NOTE (original author): in the TF version this was written for,
    # cosine_similarity is an alias of cosine_proximity with values in [0, 1];
    # the TF 2.0 version returns values in [-1, 0].  Re-check the sign of the
    # cosine loss below before changing the TF version.
    similarity_token = tf.keras.losses.cosine_similarity(y, y_pred, axis=-1)
    l2_distance_token = tf.norm(y - y_pred, ord=2, axis=-1)

    # Per-example loss; with the cosine objective the target is to reach 0.
    if FLAGS.loss == 'cosine':
        per_example_loss = 1. - tf.reduce_mean(similarity_token, axis=-1)
    elif FLAGS.loss == 'mse':
        per_example_loss = tf.reduce_mean(tf.keras.losses.mse(y, y_pred),
                                          axis=-1)
    elif FLAGS.loss == 'l2':
        per_example_loss = tf.reduce_mean(l2_distance_token, axis=-1)
    else:
        # Previously this fell through to a constant 0.0, which cannot be
        # reduced or optimised; fail with an explicit message instead.
        raise ValueError(
            "Unsupported loss '%s', expected 'cosine', 'mse' or 'l2'." %
            FLAGS.loss)

    loss = tf.reduce_mean(per_example_loss, axis=-1)
    similarity = tf.reduce_mean(tf.reduce_mean(similarity_token, axis=-1),
                                axis=-1)
    l2_distance = tf.reduce_mean(tf.reduce_mean(l2_distance_token, axis=-1),
                                 axis=-1)

    # Summaries are fed through placeholders so that values averaged over
    # several dev batches can be logged as well.
    loss_ph = tf.compat.v1.placeholder(tf.float32, name='loss')
    similarity_ph = tf.compat.v1.placeholder(tf.float32, name='similarity')
    l2_distance_ph = tf.compat.v1.placeholder(tf.float32, name='l2_distance')
    loss_sum = tf.compat.v1.summary.scalar("loss", loss_ph)
    similarity_sum = tf.compat.v1.summary.scalar("similarity", similarity_ph)
    l2_distance_sum = tf.compat.v1.summary.scalar("l2_distance",
                                                  l2_distance_ph)
    train_summary_total = tf.compat.v1.summary.merge(
        [loss_sum, similarity_sum, l2_distance_sum])
    eval_summary_total = tf.compat.v1.summary.merge(
        [loss_sum, similarity_sum, l2_distance_sum])

    log_dir = FLAGS.output_dir + 'layer_{}/'.format(layer_folder_name)

    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate)

    # Everything except the approximator is restored from the checkpoint.
    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [var for var in tvars if "non-linearity" not in var.name]
    saver_init = tf.compat.v1.train.Saver(var_init)

    # Only the approximator variables are trained and saved.
    var_approximator = [var for var in tvars if "non-linearity" in var.name]
    print("Training variables")
    for var in var_approximator:
        print(var)

    optimizer_op = optimizer.minimize(loss=loss, var_list=var_approximator)
    saver_approximator = tf.compat.v1.train.Saver(var_list=var_approximator,
                                                  max_to_keep=1)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(log_dir + 'log/train/',
                                                 sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(log_dir + 'log/eval/')

        def make_feed(ids, mask, segments):
            """Feed dict for the three model input placeholders."""
            return {
                input_ids_ph: ids,
                input_mask_ph: mask,
                segment_ids_ph: segments
            }

        def summary_feed(loss_val, similarity_val, l2_val):
            """Feed dict for the three summary placeholders."""
            return {
                loss_ph: loss_val,
                similarity_ph: similarity_val,
                l2_distance_ph: l2_val
            }

        # Small dev sets (<= 1000 examples) are loaded once and evaluated in
        # a single run; larger ones are evaluated batch by batch.
        cached_eval_feed = None
        if num_actual_eval_examples <= 1000:
            eval_ids, eval_mask, eval_segments, _, _, _ = generate_ph_input(
                batch_size=num_actual_eval_examples,
                seq_length=seq_length,
                examples=eval_examples,
                label_list=label_list,
                tokenizer=tokenizer,
                return_tokens=True)
            cached_eval_feed = make_feed(eval_ids, eval_mask, eval_segments)

        def evaluate():
            """Return (loss, similarity, l2_distance) on the dev set."""
            if cached_eval_feed is not None:
                return sess.run([loss, similarity, l2_distance],
                                feed_dict=cached_eval_feed)

            num_batch_eval = ceil_div(num_actual_eval_examples, batch_size)
            eval_offset = 0
            acc_loss, acc_similarity, acc_l2_distance = 0, 0, 0
            for _ in range(num_batch_eval):
                ids, mask, segments, _, _ = generate_ph_input(
                    batch_size=batch_size,
                    seq_length=seq_length,
                    examples=eval_examples,
                    label_list=label_list,
                    tokenizer=tokenizer,
                    train_idx_offset=eval_offset)
                eval_offset += batch_size
                batch_loss, batch_similarity, batch_l2 = sess.run(
                    [loss, similarity, l2_distance],
                    feed_dict=make_feed(ids, mask, segments))
                acc_loss += batch_loss
                acc_similarity += batch_similarity
                acc_l2_distance += batch_l2
            # Unweighted mean of the per-batch means.
            return (acc_loss / num_batch_eval,
                    acc_similarity / num_batch_eval,
                    acc_l2_distance / num_batch_eval)

        num_batch = ceil_div(num_actual_train_examples, batch_size)
        step = 1
        for n in range(epochs):
            np.random.shuffle(train_examples)
            train_offset = 0
            for b in range(num_batch):
                input_ids, input_mask, segment_ids, _, _ = generate_ph_input(
                    batch_size=batch_size,
                    seq_length=seq_length,
                    examples=train_examples,
                    label_list=label_list,
                    tokenizer=tokenizer,
                    train_idx_offset=train_offset)
                train_offset += batch_size

                train_loss, train_similarity, train_l2_distance, _ = sess.run(
                    [loss, similarity, l2_distance, optimizer_op],
                    feed_dict=make_feed(input_ids, input_mask, segment_ids))
                train_summary_str = sess.run(
                    train_summary_total,
                    feed_dict=summary_feed(train_loss, train_similarity,
                                           train_l2_distance))
                writer.add_summary(train_summary_str, step)

                # Evaluate on the first step, every 100 steps and at the end
                # of every epoch.
                if step % 100 == 0 or step % num_batch == 0 or step == 1:
                    eval_loss, eval_similarity, eval_l2_distance = evaluate()
                    eval_summary_str = sess.run(
                        eval_summary_total,
                        feed_dict=summary_feed(eval_loss, eval_similarity,
                                               eval_l2_distance))
                    writer_eval.add_summary(eval_summary_str, step)
                    print(
                        "Approximating layer: %2d | Epoch: %4d/%4d | Batch: %4d/%4d | "
                        "Train loss: %2.4f | Train similarity/L2 distance: %.4f/%2.4f | "
                        "Eval loss: %2.4f | Eval similarity/L2 distance: %.4f/%2.4f"
                        % (approximated_layer, n, epochs, b, num_batch,
                           train_loss, train_similarity, train_l2_distance,
                           eval_loss, eval_similarity, eval_l2_distance))

                # Checkpoint the approximator at the end of every epoch.
                if step % num_batch == 0:
                    saver_approximator.save(
                        sess,
                        log_dir + 'approximator_{}'.format(layer_folder_name),
                        global_step=step)
                step += 1

        writer.close()
        writer_eval.close()

        # Final evaluation: the whole dev set is fed in one run so that the
        # per-token targets and predictions can be compared on the host.
        eval_input_ids, eval_input_mask, eval_segment_ids, _, _, \
            eval_input_tokens = generate_ph_input(
                batch_size=num_actual_eval_examples,
                seq_length=seq_length,
                examples=eval_examples,
                label_list=label_list,
                tokenizer=tokenizer,
                return_tokens=True)
        eval_token_length = [len(elem) for elem in eval_input_tokens]

        ground_truths, predictions = sess.run(
            [y, y_pred],
            feed_dict=make_feed(eval_input_ids, eval_input_mask,
                                eval_segment_ids))

        cosine_similarity_actual_token = 0
        cosine_similarity_padding = 0
        cosine_similarity_all = 0
        num_sample_no_padding = 0
        for i in range(num_actual_eval_examples):
            n_tokens = eval_token_length[i]
            # Samples without padding are left out of all three averages.
            if n_tokens >= seq_length:
                num_sample_no_padding += 1
                continue
            # `cosine` is a distance, hence the `1 - ...`; computed once per
            # token and then sliced into real / padding / all.
            token_similarity = [
                1 - cosine(ground_truths[i][j], predictions[i][j])
                for j in range(seq_length)
            ]
            cosine_similarity_actual_token += np.mean(
                token_similarity[:n_tokens])
            cosine_similarity_padding += np.mean(token_similarity[n_tokens:])
            cosine_similarity_all += np.mean(token_similarity)

        num_sample_with_padding = (num_actual_eval_examples -
                                   num_sample_no_padding)
        if num_sample_with_padding > 0:
            cosine_similarity_actual_token /= num_sample_with_padding
            cosine_similarity_padding /= num_sample_with_padding
            cosine_similarity_all /= num_sample_with_padding
        else:
            # Nothing to average; report NaN instead of dividing by zero.
            tf.compat.v1.logging.warning(
                "No dev sample contains padding; similarities are undefined.")
            cosine_similarity_actual_token = float('nan')
            cosine_similarity_padding = float('nan')
            cosine_similarity_all = float('nan')

        print("Skip {} samples without paddings".format(num_sample_no_padding))

        with tf.io.gfile.GFile(FLAGS.output_dir + 'results.txt',
                               'a') as writer:
            writer.write(
                "Approximating layer %2d: Actual: %.4f/%.4f | Pad: %.4f/%.4f | All: %.4f/%.4f\n"
                % (approximated_layer, cosine_similarity_actual_token,
                   1 - cosine_similarity_actual_token,
                   cosine_similarity_padding, 1 - cosine_similarity_padding,
                   cosine_similarity_all, 1 - cosine_similarity_all))
def main(_):
    """Fine-tune BERT with all but one attention head of a layer masked out.

    In layer ``FLAGS.cur_layer`` every head except
    ``FLAGS.most_important_head`` is masked via the model's head/layer mask
    inputs.  The model is restored from ``FLAGS.init_checkpoint`` (optionally
    including the classifier weights when ``FLAGS.load_from_finetuned`` is
    set) and fine-tuned on the training set of ``FLAGS.task_name``.  Train and
    dev metrics go to summary writers below
    ``<output_dir>layer_<cur_layer>_head_<head>/log/``, a checkpoint is saved
    once per epoch, and the dev metrics of the first and the last step are
    appended to ``<output_dir>results.txt``.

    Args:
        _: Ignored positional argument.

    Raises:
        ValueError: If ``FLAGS.do_train`` is not set, the sequence length
            exceeds the model's ``max_position_embeddings`` or the task name
            is not registered.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor,
        "mrpc": MrpcProcessor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "rte": RteProcessor,
        "sst2": Sst2Processor,
        "stsb": StsbProcessor,
        "wnli": WnliProcessor,
        "ax": AxProcessor,
        "mnlimdevastest": MnliMDevAsTestProcessor,
    }

    def ceil_div(total, chunk):
        """Number of chunks of size `chunk` needed to cover `total` items."""
        return (total + chunk - 1) // chunk

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train:
        raise ValueError("At least 'do_train' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    # NOTE: every output path below is built by plain string concatenation,
    # so FLAGS.output_dir is expected to end with a path separator.
    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    # special handling for mnlimdevastest: metrics are those of mnlim
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    # hidden size, 768 for BERT-base, 512 for BERT-small
    embed_dim = FLAGS.hidden_size
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Placeholders for the model input.
    input_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Running leave the most important head per layer then fine-tune!")

    # NOTE(review): this floors the per-epoch batch count, while the training
    # loop below also runs the trailing partial batch, so the loop can execute
    # more steps than the schedule was sized for - confirm this is intended.
    num_train_steps = num_actual_train_examples // batch_size * epochs
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    cur_layer = FLAGS.cur_layer
    most_important_head = FLAGS.most_important_head
    print("Current most important head:", cur_layer, most_important_head)

    # Dropout is switched between training and evaluation through these.
    keep_prob_ph = tf.compat.v1.placeholder(tf.float32, name="keep_prob")
    is_training_ph = tf.compat.v1.placeholder(tf.bool, name='is_training')

    # Coordinates (layer, head) of the heads to mask, as two parallel lists.
    head_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                            shape=[None],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32,
                                             shape=[None],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training_ph,
        input_ids=input_ids_ph,
        input_mask=input_mask_ph,
        token_type_ids=segment_ids_ph,
        use_one_hot_embeddings=False,
        use_estimator=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    # Classification head on top of the pooled output.
    output_layer = model.get_pooled_output()
    output_layer = tf.compat.v1.nn.dropout(output_layer,
                                           keep_prob=keep_prob_ph)
    output_weights = tf.compat.v1.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.compat.v1.get_variable(
        "output_bias", [num_labels],
        initializer=tf.compat.v1.zeros_initializer())

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        if num_labels == 1:
            # Regression (stsb): squared error.
            # NOTE(review): label_ids_ph is int32 while logits are float -
            # confirm this branch builds for the regression task.
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            # Classification: softmax cross-entropy.
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)

    # Not fetched in this script; kept so the graph stays as before.
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # `metric` maps a metric name to a (value_tensor, update_op) pair.
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    log_dir = FLAGS.output_dir + 'layer_{}_head_{}/'.format(
        cur_layer, most_important_head)

    # One scalar summary per metric, fed through a placeholder.
    metric_phs = [
        tf.compat.v1.placeholder(tf.float32, name="{}_ph".format(key))
        for key in metric_name
    ]
    summaries = [
        tf.compat.v1.summary.scalar(key, metric_phs[i])
        for i, key in enumerate(metric_name)
    ]
    train_summary_total = tf.compat.v1.summary.merge(summaries)
    eval_summary_total = tf.compat.v1.summary.merge(summaries)

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [
        v for v in tvars
        if 'output_weights' not in v.name and 'output_bias' not in v.name
    ]
    var_output = [
        v for v in tvars
        if 'output_weights' in v.name or "output_bias" in v.name
    ]

    if not FLAGS.load_from_finetuned:
        # Init from Model0: the classifier weights keep their fresh init.
        saver_init = tf.compat.v1.train.Saver(var_init)
    else:
        # Init from Model1: the classifier weights are restored as well.
        saver_init = tf.compat.v1.train.Saver(var_init + var_output)

    var_train = var_init + var_output
    print("Training parameters")
    for v in var_train:
        print(v)

    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=False,
        tvars=var_train)

    saver_all = tf.compat.v1.train.Saver(var_list=var_init + var_output,
                                         max_to_keep=1)

    # Local variables behind the metric ops; re-initialising them resets the
    # running statistics before each measurement.
    var_metric = []
    for key in metric_name:
        var_metric.extend(
            tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.LOCAL_VARIABLES, scope=key))
    metric_vars_initializer = tf.compat.v1.variables_initializer(
        var_list=var_metric)

    def format_metrics(values):
        """Render metric values as "name: 0.1234|name: 0.5678"."""
        return '|'.join("{}: %.4f".format(name) % value
                        for name, value in zip(metric_name, values))

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(log_dir + 'log/train/',
                                                 sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(log_dir + 'log/eval/')

        # Heads to mask in cur_layer: all 12 except the most important one.
        # layer_mask is sized from head_mask so the two lists always pair up.
        n_heads = 12
        head_mask = [
            head for head in range(n_heads) if head != most_important_head
        ]
        layer_mask = [cur_layer] * len(head_mask)

        def make_feed(ids, mask, segments, labels, training):
            """Feed dict for one batch; dropout is active only in training."""
            return {
                input_ids_ph: ids,
                input_mask_ph: mask,
                segment_ids_ph: segments,
                label_ids_ph: labels,
                is_training_ph: training,
                keep_prob_ph: 0.9 if training else 1,
                head_mask_ph: head_mask,
                layer_mask_ph: layer_mask
            }

        # Small dev sets (<= 1000 examples) are loaded once and evaluated in
        # a single run; larger ones are evaluated batch by batch.
        cached_eval_feed = None
        if num_actual_eval_examples <= 1000:
            eval_ids, eval_mask, eval_segments, eval_labels, _ = \
                generate_ph_input(batch_size=num_actual_eval_examples,
                                  seq_length=seq_length,
                                  examples=eval_examples,
                                  label_list=label_list,
                                  tokenizer=tokenizer)
            cached_eval_feed = make_feed(eval_ids, eval_mask, eval_segments,
                                         eval_labels, False)

        def evaluate():
            """Reset the metrics, run the dev set and return metric values."""
            sess.run(metric_vars_initializer)
            if cached_eval_feed is not None:
                sess.run(metric_op, feed_dict=cached_eval_feed)
            else:
                eval_offset = 0
                num_batch_eval = ceil_div(num_actual_eval_examples,
                                          batch_size)
                for _ in range(num_batch_eval):
                    ids, mask, segments, labels, _ = generate_ph_input(
                        batch_size=batch_size,
                        seq_length=seq_length,
                        examples=eval_examples,
                        label_list=label_list,
                        tokenizer=tokenizer,
                        train_idx_offset=eval_offset)
                    eval_offset += batch_size
                    sess.run(metric_op,
                             feed_dict=make_feed(ids, mask, segments, labels,
                                                 False))
            return sess.run(metric_val)

        start_metric = {"eval_{}".format(key): 0 for key in metric_name}
        end_metric = {"eval_{}".format(key): 0 for key in metric_name}

        if FLAGS.do_train:
            tf.compat.v1.logging.info("***** Run training *****")
            num_batch = ceil_div(num_actual_train_examples, batch_size)
            step = 1
            for n in range(epochs):
                np.random.shuffle(train_examples)
                train_offset = 0
                for b in range(num_batch):
                    input_ids, input_mask, segment_ids, label_ids, _ = \
                        generate_ph_input(batch_size=batch_size,
                                          seq_length=seq_length,
                                          examples=train_examples,
                                          label_list=label_list,
                                          tokenizer=tokenizer,
                                          train_idx_offset=train_offset)
                    train_offset += batch_size

                    # Train metrics are per batch: reset, update, read.
                    sess.run(metric_vars_initializer)
                    sess.run([train_op] + metric_op,
                             feed_dict=make_feed(input_ids, input_mask,
                                                 segment_ids, label_ids,
                                                 True))
                    train_metric_val = sess.run(metric_val)
                    train_summary_str = sess.run(
                        train_summary_total,
                        feed_dict=dict(zip(metric_phs, train_metric_val)))
                    writer.add_summary(train_summary_str, step)

                    # Evaluate on the first step, every 100 steps and at the
                    # end of every epoch.
                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        eval_metric_val = evaluate()
                        eval_summary_str = sess.run(
                            eval_summary_total,
                            feed_dict=dict(zip(metric_phs, eval_metric_val)))
                        writer_eval.add_summary(eval_summary_str, step)

                        if step == 1:
                            for key, val in zip(metric_name, eval_metric_val):
                                start_metric["eval_{}".format(key)] = val
                        if step == epochs * num_batch:
                            for key, val in zip(metric_name, eval_metric_val):
                                end_metric["eval_{}".format(key)] = val

                        train_str = 'Train ' + format_metrics(
                            train_metric_val)
                        eval_str = 'Eval ' + format_metrics(eval_metric_val)
                        print(
                            "Layer {}, leave only one head {} | Epoch: %4d/%4d | Batch: %4d/%4d | {} | {}"
                            .format(cur_layer, most_important_head, train_str,
                                    eval_str) % (n, epochs, b, num_batch))

                    # Checkpoint at the end of every epoch.
                    if step % num_batch == 0:
                        saver_all.save(sess,
                                       log_dir + 'layer{}_head_{}'.format(
                                           cur_layer, most_important_head),
                                       global_step=step)
                    step += 1

        writer.close()
        writer_eval.close()

        print("Start metric", start_metric)
        print("End metric", end_metric)

        with tf.io.gfile.GFile(FLAGS.output_dir + 'results.txt',
                               'a') as writer:
            # The loss is left out of the one-line summary.
            eval_start, eval_end = [], []
            for name in metric_name:
                if name != 'loss':
                    eval_start.append("{}: %.4f".format(name) %
                                      start_metric["eval_{}".format(name)])
                    eval_end.append("{}: %.4f".format(name) %
                                    end_metric["eval_{}".format(name)])

            writer.write(
                "Layer {}, leave only one head {}: Start: {} | End: {}\n".
                format(cur_layer, most_important_head, ','.join(eval_start),
                       ','.join(eval_end)))