def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    processors = {
        "cola": ColaProcessor, "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor, "mrpc": MrpcProcessor,
        "qnli": QnliProcessor, "qqp": QqpProcessor,
        "rte": RteProcessor, "sst2": Sst2Processor,
        "stsb": StsbProcessor, "wnli": WnliProcessor,
        "ax": AxProcessor, "mnlimdevastest": MnliMDevAsTestProcessor
    }
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_eval:
        raise ValueError("At least 'do_eval' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.eval_batch_size
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input.
    input_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Running single-head masking out and direct evaluation!")

    # We want to mask out each individual head and then evaluate,
    # so there are 12 layers * 12 heads results.
    n_layers = 12
    n_heads = 12
    folder = FLAGS.output_dir
    save_file = 'single_head_mask.pickle'
    output = np.zeros((n_layers, n_heads))

    # Two placeholders for the head coordinates: layer, head.
    head_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,  # input_ids
        input_mask=input_mask_ph,  # input_mask
        token_type_ids=segment_ids_ph,  # segment_ids
        use_one_hot_embeddings=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    output_layer = model.get_pooled_output()

    output_weights = tf.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb (regression), num_labels == 1
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph, depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)

    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # Metric and summary; each metric entry is a tf.metrics pair (val, op).
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.train.Saver(tvars)

    # Isolate the variables stored behind the scenes by the metric operation.
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables.
    metric_vars_initializer = tf.variables_initializer(var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        # If the number of eval examples is <= 1000, load it in one shot;
        # otherwise load it batch by batch inside the loop below.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
                eval_label_ids, eval_is_real_example = generate_ph_input(
                    batch_size=num_actual_eval_examples,
                    seq_length=seq_length,
                    examples=eval_examples,
                    label_list=label_list,
                    tokenizer=tokenizer)

        # Loop over layers, then loop over heads.
        for l in range(n_layers):
            for h in range(n_heads):
                cur_l, cur_h = l, h
                head_mask = [h]
                layer_mask = [l]

                if num_actual_eval_examples <= 1000:
                    sess.run(metric_vars_initializer)
                    sess.run(metric_op, feed_dict={
                        input_ids_ph: eval_input_ids,
                        input_mask_ph: eval_input_mask,
                        segment_ids_ph: eval_segment_ids,
                        label_ids_ph: eval_label_ids,
                        head_mask_ph: head_mask,
                        layer_mask_ph: layer_mask})
                    eval_metric_val = sess.run(metric_val)
                else:
                    num_batch_eval = num_actual_eval_examples // batch_size \
                        if num_actual_eval_examples % batch_size == 0 \
                        else num_actual_eval_examples // batch_size + 1
                    id_eval = 0
                    sess.run(metric_vars_initializer)
                    for _ in range(num_batch_eval):
                        eval_input_ids, eval_input_mask, eval_segment_ids, \
                            eval_label_ids, eval_is_real_example = generate_ph_input(
                                batch_size=batch_size,
                                seq_length=seq_length,
                                examples=eval_examples,
                                label_list=label_list,
                                tokenizer=tokenizer,
                                train_idx_offset=id_eval)
                        id_eval += batch_size
                        sess.run(metric_op, feed_dict={
                            input_ids_ph: eval_input_ids,
                            input_mask_ph: eval_input_mask,
                            segment_ids_ph: eval_segment_ids,
                            label_ids_ph: eval_label_ids,
                            head_mask_ph: head_mask,
                            layer_mask_ph: layer_mask})
                    eval_metric_val = sess.run(metric_val)

                for name, val in zip(metric_name, eval_metric_val):
                    if name == 'accuracy':
                        output[cur_l][cur_h] = val
                        print(
                            "Mask out the head in (Layer {}, Head {}) | {}: {}"
                            .format(cur_l, cur_h, name, val))

    joblib.dump(output, folder + save_file)
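
# Illustrative sketch, not part of the original script: one possible way to load and
# inspect the 12 x 12 accuracy matrix that the single-head masking run dumps via
# joblib above. The helper name `inspect_single_head_mask` and the default path are
# hypothetical; point `path` at FLAGS.output_dir + 'single_head_mask.pickle'.
def inspect_single_head_mask(path='single_head_mask.pickle'):
    acc = joblib.load(path)  # shape (n_layers, n_heads): accuracy with that single head masked
    layer, head = np.unravel_index(np.argmin(acc), acc.shape)
    print("Lowest accuracy when masking layer %d, head %d: %.4f"
          % (layer, head, acc[layer, head]))
    return acc
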

def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    processors = {
        "cola": ColaProcessor, "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor, "mrpc": MrpcProcessor,
        "qnli": QnliProcessor, "qqp": QqpProcessor,
        "rte": RteProcessor, "sst2": Sst2Processor,
        "stsb": StsbProcessor, "wnli": WnliProcessor,
        "ax": AxProcessor, "mnlimdevastest": MnliMDevAsTestProcessor
    }
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_pred:
        raise ValueError(
            "At least one of 'do_train', 'do_eval' or 'do_pred' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    if FLAGS.do_pred:
        test_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_test_examples = len(test_examples)
        print("num_actual_test_examples", num_actual_test_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input.
    input_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                            name='label_ids')

    tf.compat.v1.logging.info("Running swapping layers!")

    num_train_steps = num_actual_train_examples // batch_size * epochs
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Swapped layers.
    layer_folder_name = FLAGS.layers
    if FLAGS.layers is None:
        raise ValueError("In swapping experiments, layers must not be None.")
    layers = list(map(int, FLAGS.layers.split(',')))
    layer_order = list(range(12)) if embed_dim == 768 else list(range(6))
    if len(layers) == 2:
        layer1, layer2 = layers[0], layers[1]
        layers.sort()
        if layer1 != layer2:
            layer_order[layer1], layer_order[layer2] = \
                layer_order[layer2], layer_order[layer1]
        else:
            raise ValueError("Two layers should be different!")
    else:
        layer_order = layer_order[:layers[0]] + layers[::-1] + \
            layer_order[layers[-1] + 1:]
    print("Current layer order: ", layer_order)

    # These placeholders control the dropout and training flags.
    keep_prob_ph = tf.compat.v1.placeholder(tf.float32, name="keep_prob")
    is_training_ph = tf.compat.v1.placeholder(tf.bool, name='is_training')

    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training_ph,
        input_ids=input_ids_ph,  # input_ids
        input_mask=input_mask_ph,  # input_mask
        token_type_ids=segment_ids_ph,  # segment_ids
        use_one_hot_embeddings=False,
        layer_order=layer_order,
        use_estimator=False)

    output_layer = model.get_pooled_output()
    output_layer = tf.nn.dropout(output_layer, keep_prob=keep_prob_ph)

    output_weights = tf.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb (regression), num_labels == 1
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph, depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)

    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # Metric and summary; each metric entry is a tf.metrics pair (val, op).
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    metric_phs = [
        tf.compat.v1.placeholder(tf.float32, name="{}_ph".format(key))
        for key in metric.keys()
    ]
    summaries = [
        tf.compat.v1.summary.scalar(key, metric_phs[i])
        for i, key in enumerate(metric.keys())
    ]
    train_summary_total = tf.summary.merge(summaries)
    eval_summary_total = tf.summary.merge(summaries)

    log_dir = FLAGS.output_dir + 'layer_{}/'.format(layer_folder_name)

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [
        v for v in tvars
        if 'output_weights' not in v.name and 'output_bias' not in v.name
    ]
    var_output = [
        v for v in tvars
        if 'output_weights' in v.name or "output_bias" in v.name
    ]
    if not FLAGS.load_from_finetuned:
        # Init from Model0.
        saver_init = tf.train.Saver(var_init)
    else:
        # Init from Model1.
        saver_init = tf.train.Saver(var_init + var_output)

    var_train = var_init + var_output
    print("Training parameters")
    for v in var_train:
        print(v)

    train_op = optimization.create_optimizer(loss=loss,
                                             init_lr=FLAGS.learning_rate,
                                             num_train_steps=num_train_steps,
                                             num_warmup_steps=num_warmup_steps,
                                             use_tpu=False,
                                             tvars=var_train)
    saver_all = tf.train.Saver(var_list=var_train, max_to_keep=1)

    # Isolate the variables stored behind the scenes by the metric operation.
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables.
    metric_vars_initializer = tf.variables_initializer(var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(log_dir + 'log/train/',
                                                 sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(log_dir + 'log/eval/')

        # If the number of eval examples is <= 1000, load it in one shot;
        # otherwise load it batch by batch where needed.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
                eval_label_ids, eval_is_real_example = generate_ph_input(
                    batch_size=num_actual_eval_examples,
                    seq_length=seq_length,
                    examples=eval_examples,
                    label_list=label_list,
                    tokenizer=tokenizer)

        start_metric = {"eval_{}".format(key): 0 for key in metric_name}

        if FLAGS.do_train:
            tf.logging.info("***** Run training *****")
            step = 1
            for n in range(epochs):
                np.random.shuffle(train_examples)
                num_batch = num_actual_train_examples // batch_size \
                    if num_actual_train_examples % batch_size == 0 \
                    else num_actual_train_examples // batch_size + 1
                id = 0
                for b in range(num_batch):
                    input_ids, input_mask, segment_ids, label_ids, \
                        is_real_example = generate_ph_input(
                            batch_size=batch_size,
                            seq_length=seq_length,
                            examples=train_examples,
                            label_list=label_list,
                            tokenizer=tokenizer,
                            train_idx_offset=id)
                    id += batch_size

                    sess.run(metric_vars_initializer)
                    sess.run([train_op] + metric_op, feed_dict={
                        input_ids_ph: input_ids,
                        input_mask_ph: input_mask,
                        segment_ids_ph: segment_ids,
                        label_ids_ph: label_ids,
                        is_training_ph: True,
                        keep_prob_ph: 0.9})
                    train_metric_val = sess.run(metric_val)
                    train_summary_str = sess.run(
                        train_summary_total,
                        feed_dict={ph: value for ph, value in
                                   zip(metric_phs, train_metric_val)})
                    writer.add_summary(train_summary_str, step)

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        # Evaluate on the dev set.
                        if num_actual_eval_examples <= 1000:
                            sess.run(metric_vars_initializer)
                            sess.run(metric_op, feed_dict={
                                input_ids_ph: eval_input_ids,
                                input_mask_ph: eval_input_mask,
                                segment_ids_ph: eval_segment_ids,
                                label_ids_ph: eval_label_ids,
                                is_training_ph: False,
                                keep_prob_ph: 1})
                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={ph: value for ph, value in
                                           zip(metric_phs, eval_metric_val)})
                        else:
                            num_batch_eval = num_actual_eval_examples // batch_size \
                                if num_actual_eval_examples % batch_size == 0 \
                                else num_actual_eval_examples // batch_size + 1
                            id_eval = 0
                            sess.run(metric_vars_initializer)
                            for _ in range(num_batch_eval):
                                eval_input_ids, eval_input_mask, eval_segment_ids, \
                                    eval_label_ids, eval_is_real_example = generate_ph_input(
                                        batch_size=batch_size,
                                        seq_length=seq_length,
                                        examples=eval_examples,
                                        label_list=label_list,
                                        tokenizer=tokenizer,
                                        train_idx_offset=id_eval)
                                id_eval += batch_size
                                sess.run(metric_op, feed_dict={
                                    input_ids_ph: eval_input_ids,
                                    input_mask_ph: eval_input_mask,
                                    segment_ids_ph: eval_segment_ids,
                                    label_ids_ph: eval_label_ids,
                                    is_training_ph: False,
                                    keep_prob_ph: 1})
                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={ph: value for ph, value in
                                           zip(metric_phs, eval_metric_val)})
                        writer_eval.add_summary(eval_summary_str, step)

                        if step == 1:
                            for key, val in zip(metric_name, eval_metric_val):
                                start_metric["eval_{}".format(key)] = val

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        train_metric_list = []
                        for i in range(len(train_metric_val)):
                            if metric_name[i] == 'loss':
                                train_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                            else:
                                train_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                        train_str = 'Train ' + '|'.join(train_metric_list)

                        eval_metric_list = []
                        for i in range(len(eval_metric_val)):
                            if metric_name[i] == 'loss':
                                eval_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                            else:
                                eval_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                        eval_str = 'Eval ' + '|'.join(eval_metric_list)

                        print(
                            "Swap layer order {} | Epoch: %4d/%4d | Batch: %4d/%4d | {} | {}"
                            .format(layer_folder_name, train_str, eval_str) %
                            (n, epochs, b, num_batch))

                    if step % num_batch == 0:
                        saver_all.save(sess,
                                       log_dir + 'swap_{}'.format(layer_folder_name),
                                       global_step=step)
                    step += 1

        writer.close()
        writer_eval.close()

        end_metric = {"eval_{}".format(key): 0 for key in metric_name}

        if FLAGS.do_eval:
            tf.logging.info("***** Run evaluation *****")
            if num_actual_eval_examples <= 1000:
                sess.run(metric_vars_initializer)
                sess.run(metric_op, feed_dict={
                    input_ids_ph: eval_input_ids,
                    input_mask_ph: eval_input_mask,
                    segment_ids_ph: eval_segment_ids,
                    label_ids_ph: eval_label_ids,
                    is_training_ph: False,
                    keep_prob_ph: 1})
                eval_metric_val = sess.run(metric_val)
                preds = sess.run(predictions, feed_dict={
                    input_ids_ph: eval_input_ids,
                    input_mask_ph: eval_input_mask,
                    segment_ids_ph: eval_segment_ids,
                    label_ids_ph: eval_label_ids,
                    is_training_ph: False,
                    keep_prob_ph: 1})
                eval_label_ids_lst = eval_label_ids
            else:
                num_batch_eval = num_actual_eval_examples // batch_size \
                    if num_actual_eval_examples % batch_size == 0 \
                    else num_actual_eval_examples // batch_size + 1
                id_eval = 0
                preds = np.zeros(num_actual_eval_examples)
                eval_label_ids_lst = np.zeros(num_actual_eval_examples)
                sess.run(metric_vars_initializer)
                for i in range(num_batch_eval):
                    eval_input_ids, eval_input_mask, eval_segment_ids, \
                        eval_label_ids, eval_is_real_example = generate_ph_input(
                            batch_size=batch_size,
                            seq_length=seq_length,
                            examples=eval_examples,
                            label_list=label_list,
                            tokenizer=tokenizer,
                            train_idx_offset=id_eval)
                    id_eval += batch_size
                    sess.run(metric_op, feed_dict={
                        input_ids_ph: eval_input_ids,
                        input_mask_ph: eval_input_mask,
                        segment_ids_ph: eval_segment_ids,
                        label_ids_ph: eval_label_ids,
                        is_training_ph: False,
                        keep_prob_ph: 1})
                    pred = sess.run(predictions, feed_dict={
                        input_ids_ph: eval_input_ids,
                        input_mask_ph: eval_input_mask,
                        segment_ids_ph: eval_segment_ids,
                        label_ids_ph: eval_label_ids,
                        is_training_ph: False,
                        keep_prob_ph: 1})
                    preds[i * batch_size:min(id_eval,
                                             num_actual_eval_examples)] = pred[:]
                    eval_label_ids_lst[i * batch_size:min(
                        id_eval, num_actual_eval_examples)] = eval_label_ids[:]
                eval_metric_val = sess.run(metric_val)

            for key, val in zip(metric_name, eval_metric_val):
                end_metric["eval_{}".format(key)] = val

            output_predict_file = os.path.join(log_dir, 'dev_predictions.tsv')
            writer_output = tf.io.gfile.GFile(output_predict_file, "w")
            preds = preds.astype(int)
            eval_label_ids_lst = eval_label_ids_lst.astype(int)
            if task_name != 'stsb':
                writer_output.write(
                    "ID \t Label ID \t Label \t Ground Truth ID \t Ground Truth \n")
            else:
                writer_output.write("ID \t Label \n")
            for (i, pred) in enumerate(preds):
                if task_name != 'stsb':
                    writer_output.write("{} \t {} \t {} \t {} \t {} \n".format(
                        i, pred, label_list[pred], eval_label_ids_lst[i],
                        label_list[eval_label_ids_lst[i]]))
                else:
                    writer_output.write("{} \t {} \n".format(i, pred))
            writer_output.close()
            tf.logging.info("***** Finish writing *****")

        print("Start metric", start_metric)
        print("End metric", end_metric)

        test_metric = {"test_{}".format(key): 0 for key in metric_name}

        if FLAGS.do_pred:
            # If the number of test examples is <= 1000, load it in one shot;
            # otherwise load it batch by batch.
            # Prediction.
            tf.logging.info("***** Predict results *****")
            if num_actual_test_examples <= 1000:
                test_input_ids, test_input_mask, test_segment_ids, \
                    test_label_ids, test_is_real_example = generate_ph_input(
                        batch_size=num_actual_test_examples,
                        seq_length=seq_length,
                        examples=test_examples,
                        label_list=label_list,
                        tokenizer=tokenizer)
                sess.run(metric_vars_initializer)
                sess.run(metric_op, feed_dict={
                    input_ids_ph: test_input_ids,
                    input_mask_ph: test_input_mask,
                    segment_ids_ph: test_segment_ids,
                    label_ids_ph: test_label_ids,
                    is_training_ph: False,
                    keep_prob_ph: 1})
                test_metric_val = sess.run(metric_val)
                preds = sess.run(predictions, feed_dict={
                    input_ids_ph: test_input_ids,
                    input_mask_ph: test_input_mask,
                    segment_ids_ph: test_segment_ids,
                    label_ids_ph: test_label_ids,
                    is_training_ph: False,
                    keep_prob_ph: 1})
                test_label_ids_lst = test_label_ids
            else:
                num_batch_test = num_actual_test_examples // batch_size \
                    if num_actual_test_examples % batch_size == 0 \
                    else num_actual_test_examples // batch_size + 1
                id_test = 0
                preds = np.zeros(num_actual_test_examples)
                test_label_ids_lst = np.zeros(num_actual_test_examples)
                sess.run(metric_vars_initializer)
                for i in range(num_batch_test):
                    test_input_ids, test_input_mask, test_segment_ids, \
                        test_label_ids, test_is_real_example = generate_ph_input(
                            batch_size=batch_size,
                            seq_length=seq_length,
                            examples=test_examples,
                            label_list=label_list,
                            tokenizer=tokenizer,
                            train_idx_offset=id_test)
                    id_test += batch_size
                    sess.run(metric_op, feed_dict={
                        input_ids_ph: test_input_ids,
                        input_mask_ph: test_input_mask,
                        segment_ids_ph: test_segment_ids,
                        label_ids_ph: test_label_ids,
                        is_training_ph: False,
                        keep_prob_ph: 1})
                    pred = sess.run(predictions, feed_dict={
                        input_ids_ph: test_input_ids,
                        input_mask_ph: test_input_mask,
                        segment_ids_ph: test_segment_ids,
                        label_ids_ph: test_label_ids,
                        is_training_ph: False,
                        keep_prob_ph: 1})
                    preds[i * batch_size:min(id_test,
                                             num_actual_test_examples)] = pred[:]
                    test_label_ids_lst[i * batch_size:min(
                        id_test, num_actual_test_examples)] = test_label_ids[:]
                test_metric_val = sess.run(metric_val)

            for key, val in zip(metric_name, test_metric_val):
                test_metric["test_{}".format(key)] = val

            output_predict_file = os.path.join(log_dir, 'test_predictions.tsv')
            submit_predict_file = os.path.join(
                log_dir, "{}.tsv".format(standard_file_name[task_name]))
            writer_output = tf.io.gfile.GFile(output_predict_file, "w")
            writer_submit = tf.io.gfile.GFile(submit_predict_file, 'w')
            preds = preds.astype(int)
            test_label_ids_lst = test_label_ids_lst.astype(int)
            if task_name != 'stsb':
                writer_output.write(
                    "ID \t Label ID \t Label \t Ground Truth ID \t Ground Truth \n")
            else:
                writer_output.write("ID \t Label \n")
            writer_submit.write("ID \t Label \n")
            for (i, pred) in enumerate(preds):
                if task_name != 'stsb':
                    writer_output.write("{} \t {} \t {} \t {} \t {} \n".format(
                        i, pred, label_list[pred], test_label_ids_lst[i],
                        label_list[test_label_ids_lst[i]]))
                    writer_submit.write("{} \t {} \n".format(i, label_list[pred]))
                else:
                    writer_output.write("{} \t {} \n".format(i, pred))
                    writer_submit.write("{} \t {} \n".format(i, pred))
            writer_output.close()
            writer_submit.close()
            tf.logging.info("***** Finish writing *****")

        with tf.io.gfile.GFile(FLAGS.output_dir + 'results.txt', 'a') as writer:
            eval_start, eval_end, test_end = [], [], []
            for metric in metric_name:
                if metric != 'loss':
                    eval_start.append("{}: %.4f".format(metric) %
                                      start_metric["eval_{}".format(metric)])
                    eval_end.append("{}: %.4f".format(metric) %
                                    end_metric["eval_{}".format(metric)])
                    test_end.append("{}: %.4f".format(metric) %
                                    test_metric["test_{}".format(metric)])
            writer.write(
                "Swap layer order {}: Dev start: {} | Dev end: {} | Test end: {}\n"
                .format(layer_folder_name, ','.join(eval_start),
                        ','.join(eval_end), ','.join(test_end)))

def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    processors = {
        "cola": ColaProcessor, "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor, "mrpc": MrpcProcessor,
        "qnli": QnliProcessor, "qqp": QqpProcessor,
        "rte": RteProcessor, "sst2": Sst2Processor,
        "stsb": StsbProcessor, "wnli": WnliProcessor,
        "ax": AxProcessor
    }
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_eval:
        raise ValueError("At least 'do_eval' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.train_batch_size
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input.
    input_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Running cumulatively mask out heads and direct evaluation!")

    # We mask out heads cumulatively and evaluate after each step,
    # so there are 12 layers * 12 heads results.
    n_layers = 12
    n_heads = 12
    folder = FLAGS.output_dir
    save_file = 'cumulative_heads_mask.pickle'
    output = np.zeros((n_layers, n_heads))

    importance_coordinates = []
    if FLAGS.importance_setting == 'l2_norm':
        #######################
        # 12 * 12, layer * head, importance increases from 1 to 144.
        # [[127, 63, 67, 72, 91, 93, 124, 100, 96, 134, 15, 133],
        #  [143, 60, 107, 128, 57, 106, 118, 83, 144, 135, 99, 116],
        #  [30, 111, 115, 132, 112, 94, 122, 66, 123, 40, 75, 89],
        #  [108, 140, 141, 119, 121, 62, 137, 131, 125, 70, 85, 105],
        #  [126, 120, 113, 136, 92, 79, 110, 74, 103, 84, 86, 53],
        #  [69, 87, 117, 80, 142, 114, 129, 104, 97, 18, 52, 77],
        #  [81, 88, 48, 90, 56, 50, 58, 101, 130, 64, 35, 46],
        #  [20, 41, 38, 32, 71, 59, 82, 43, 78, 55, 47, 37],
        #  [24, 95, 27, 44, 65, 12, 28, 102, 98, 23, 14, 19],
        #  [17, 16, 2, 1, 9, 13, 68, 4, 139, 7, 21, 109],
        #  [76, 26, 8, 138, 10, 29, 31, 54, 6, 36, 3, 49],
        #  [61, 51, 42, 45, 34, 5, 73, 33, 25, 22, 39, 11]]
        #######################
        # Head coordinate array sorted by importance (L2 weight magnitude),
        # (layer, head), least important (1) --> most important (144).
        importance_coordinates = [
            [9, 3], [9, 2], [10, 10], [9, 7], [11, 5], [10, 8], [9, 9], [10, 2],
            [9, 4], [10, 4], [11, 11], [8, 5], [9, 5], [8, 10], [0, 10], [9, 1],
            [9, 0], [5, 9], [8, 11], [7, 0], [9, 10], [11, 9], [8, 9], [8, 0],
            [11, 8], [10, 1], [8, 2], [8, 6], [10, 5], [2, 0], [10, 6], [7, 3],
            [11, 7], [11, 4], [6, 10], [10, 9], [7, 11], [7, 2], [11, 10], [2, 9],
            [7, 1], [11, 2], [7, 7], [8, 3], [11, 3], [6, 11], [7, 10], [6, 2],
            [10, 11], [6, 5], [11, 1], [5, 10], [4, 11], [10, 7], [7, 9], [6, 4],
            [1, 4], [6, 6], [7, 5], [1, 1], [11, 0], [3, 5], [0, 1], [6, 9],
            [8, 4], [2, 7], [0, 2], [9, 6], [5, 0], [3, 9], [7, 4], [0, 3],
            [11, 6], [4, 7], [2, 10], [10, 0], [5, 11], [7, 8], [4, 5], [5, 3],
            [6, 0], [7, 6], [1, 7], [4, 9], [3, 10], [4, 10], [5, 1], [6, 1],
            [2, 11], [6, 3], [0, 4], [4, 4], [0, 5], [2, 5], [8, 1], [0, 8],
            [5, 8], [8, 8], [1, 10], [0, 7], [6, 7], [8, 7], [4, 8], [5, 7],
            [3, 11], [1, 5], [1, 2], [3, 0], [9, 11], [4, 6], [2, 1], [2, 4],
            [4, 2], [5, 5], [2, 2], [1, 11], [5, 2], [1, 6], [3, 3], [4, 1],
            [3, 4], [2, 6], [2, 8], [0, 6], [3, 8], [4, 0], [0, 0], [1, 3],
            [5, 6], [6, 8], [3, 7], [2, 3], [0, 11], [0, 9], [1, 9], [4, 3],
            [3, 6], [10, 3], [9, 8], [3, 1], [3, 2], [5, 4], [1, 0], [1, 8]]
    elif FLAGS.importance_setting == 'per_head_score':
        #######################
        # 12 * 12, layer * head, importance increases from 1 to 144.
        # [[127, 63, 67, 72, 91, 93, 124, 100, 96, 134, 15, 133],
        #  [143, 60, 107, 128, 57, 106, 118, 83, 144, 135, 99, 116],
        #  [30, 111, 115, 132, 112, 94, 122, 66, 123, 40, 75, 89],
        #  [108, 140, 141, 119, 121, 62, 137, 131, 125, 70, 85, 105],
        #  [126, 120, 113, 136, 92, 79, 110, 74, 103, 84, 86, 53],
        #  [69, 87, 117, 80, 142, 114, 129, 104, 97, 18, 52, 77],
        #  [81, 88, 48, 90, 56, 50, 58, 101, 130, 64, 35, 46],
        #  [20, 41, 38, 32, 71, 59, 82, 43, 78, 55, 47, 37],
        #  [24, 95, 27, 44, 65, 12, 28, 102, 98, 23, 14, 19],
        #  [17, 16, 2, 1, 9, 13, 68, 4, 139, 7, 21, 109],
        #  [76, 26, 8, 138, 10, 29, 31, 54, 6, 36, 3, 49],
        #  [61, 51, 42, 45, 34, 5, 73, 33, 25, 22, 39, 11]]
        #######################
        # Head coordinate array sorted by importance, (layer, head),
        # least important (1) --> most important (144).
        importance_coordinates = [
            [9, 3], [9, 2], [10, 10], [9, 7], [11, 5], [10, 8], [9, 9], [10, 2],
            [9, 4], [10, 4], [11, 11], [8, 5], [9, 5], [8, 10], [0, 10], [9, 1],
            [9, 0], [5, 9], [8, 11], [7, 0], [9, 10], [11, 9], [8, 9], [8, 0],
            [11, 8], [10, 1], [8, 2], [8, 6], [10, 5], [2, 0], [10, 6], [7, 3],
            [11, 7], [11, 4], [6, 10], [10, 9], [7, 11], [7, 2], [11, 10], [2, 9],
            [7, 1], [11, 2], [7, 7], [8, 3], [11, 3], [6, 11], [7, 10], [6, 2],
            [10, 11], [6, 5], [11, 1], [5, 10], [4, 11], [10, 7], [7, 9], [6, 4],
            [1, 4], [6, 6], [7, 5], [1, 1], [11, 0], [3, 5], [0, 1], [6, 9],
            [8, 4], [2, 7], [0, 2], [9, 6], [5, 0], [3, 9], [7, 4], [0, 3],
            [11, 6], [4, 7], [2, 10], [10, 0], [5, 11], [7, 8], [4, 5], [5, 3],
            [6, 0], [7, 6], [1, 7], [4, 9], [3, 10], [4, 10], [5, 1], [6, 1],
            [2, 11], [6, 3], [0, 4], [4, 4], [0, 5], [2, 5], [8, 1], [0, 8],
            [5, 8], [8, 8], [1, 10], [0, 7], [6, 7], [8, 7], [4, 8], [5, 7],
            [3, 11], [1, 5], [1, 2], [3, 0], [9, 11], [4, 6], [2, 1], [2, 4],
            [4, 2], [5, 5], [2, 2], [1, 11], [5, 2], [1, 6], [3, 3], [4, 1],
            [3, 4], [2, 6], [2, 8], [0, 6], [3, 8], [4, 0], [0, 0], [1, 3],
            [5, 6], [6, 8], [3, 7], [2, 3], [0, 11], [0, 9], [1, 9], [4, 3],
            [3, 6], [10, 3], [9, 8], [3, 1], [3, 2], [5, 4], [1, 0], [1, 8]]

    # Mask out heads starting from the most important one.
    if FLAGS.from_most:
        importance_coordinates.reverse()

    # Two placeholders for the head coordinates: layer, head.
    head_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids_ph,  # input_ids
        input_mask=input_mask_ph,  # input_mask
        token_type_ids=segment_ids_ph,  # segment_ids
        use_one_hot_embeddings=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    output_layer = model.get_pooled_output()

    output_weights = tf.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb (regression), num_labels == 1
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph, depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)

    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # Metric and summary; each metric entry is a tf.metrics pair (val, op).
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    saver_init = tf.train.Saver(tvars)

    # Isolate the variables stored behind the scenes by the metric operation.
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables.
    metric_vars_initializer = tf.variables_initializer(var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        # If the number of eval examples is <= 1000, load it in one shot;
        # otherwise load it batch by batch inside the loop below.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
                eval_label_ids, eval_is_real_example = generate_ph_input(
                    batch_size=num_actual_eval_examples,
                    seq_length=seq_length,
                    examples=eval_examples,
                    label_list=label_list,
                    tokenizer=tokenizer)

        # Loop over layers, then loop over heads, following the importance order.
        idx = 0
        for l in range(n_layers):
            for h in range(n_heads):
                cur_l, cur_h = importance_coordinates[idx]
                coor_mask = importance_coordinates[0:idx + 1]
                head_mask = [head for _, head in coor_mask]
                layer_mask = [layer for layer, _ in coor_mask]

                if num_actual_eval_examples <= 1000:
                    sess.run(metric_vars_initializer)
                    sess.run(metric_op, feed_dict={
                        input_ids_ph: eval_input_ids,
                        input_mask_ph: eval_input_mask,
                        segment_ids_ph: eval_segment_ids,
                        label_ids_ph: eval_label_ids,
                        head_mask_ph: head_mask,
                        layer_mask_ph: layer_mask})
                    eval_metric_val = sess.run(metric_val)
                else:
                    num_batch_eval = num_actual_eval_examples // batch_size \
                        if num_actual_eval_examples % batch_size == 0 \
                        else num_actual_eval_examples // batch_size + 1
                    id_eval = 0
                    sess.run(metric_vars_initializer)
                    for _ in range(num_batch_eval):
                        eval_input_ids, eval_input_mask, eval_segment_ids, \
                            eval_label_ids, eval_is_real_example = generate_ph_input(
                                batch_size=batch_size,
                                seq_length=seq_length,
                                examples=eval_examples,
                                label_list=label_list,
                                tokenizer=tokenizer,
                                train_idx_offset=id_eval)
                        id_eval += batch_size
                        sess.run(metric_op, feed_dict={
                            input_ids_ph: eval_input_ids,
                            input_mask_ph: eval_input_mask,
                            segment_ids_ph: eval_segment_ids,
                            label_ids_ph: eval_label_ids,
                            head_mask_ph: head_mask,
                            layer_mask_ph: layer_mask})
                    eval_metric_val = sess.run(metric_val)

                for name, val in zip(metric_name, eval_metric_val):
                    if name == 'accuracy':
                        output[cur_l][cur_h] = val
                        print(
                            "Mask out the {}th head in (Layer {}, Head {}) | {}: {}"
                            .format(idx + 1, cur_l, cur_h, name, val))
                idx += 1

    joblib.dump(output, folder + save_file)
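
# Illustrative sketch of how the cumulative masks grow in the loop above.
# `cumulative_masks` is a hypothetical helper, not part of the original script: at step
# idx, every head up to and including the idx-th entry of the importance ordering is
# masked at the same time.
def cumulative_masks(importance_coordinates, idx):
    coor_mask = importance_coordinates[:idx + 1]
    head_mask = [head for _, head in coor_mask]
    layer_mask = [layer for layer, _ in coor_mask]
    return head_mask, layer_mask

# Example: with an ordering that starts [[9, 3], [9, 2], [10, 10], ...],
# cumulative_masks(order, 2) -> ([3, 2, 10], [9, 9, 10]), i.e. heads 3 and 2 of
# layer 9 plus head 10 of layer 10 are masked together at the third step.
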

def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    processors = {
        "cola": ColaProcessor, "mnlim": MnliMProcessor,
        "mnlimm": MnliMMProcessor, "mrpc": MrpcProcessor,
        "qnli": QnliProcessor, "qqp": QqpProcessor,
        "rte": RteProcessor, "sst2": Sst2Processor,
        "stsb": StsbProcessor, "wnli": WnliProcessor,
        "ax": AxProcessor, "mnlimdevastest": MnliMDevAsTestProcessor
    }
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train:
        raise ValueError("At least 'do_train' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    print("Current task", task_name)
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    # special handling for mnlimdevastest
    if task_name == 'mnlimdevastest':
        task_name = 'mnlim'

    label_list = processor.get_labels()
    print("Label list of current task", label_list)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = processor.get_train_examples(FLAGS.data_dir)
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_train_examples = len(train_examples)
    num_actual_eval_examples = len(eval_examples)
    print("num_actual_train_examples", num_actual_train_examples)
    print("num_actual_eval_examples", num_actual_eval_examples)

    batch_size = FLAGS.train_batch_size
    epochs = FLAGS.num_train_epochs
    embed_dim = FLAGS.hidden_size  # hidden size, 768 for BERT-base, 512 for BERT-small
    seq_length = FLAGS.max_seq_length
    num_labels = len(label_list)

    # Define some placeholders for the input.
    input_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                            name='input_ids')
    input_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                             name='input_mask')
    segment_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None, seq_length],
                                              name='segment_ids')
    label_ids_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                            name='label_ids')

    tf.compat.v1.logging.info(
        "Running leave the most important head per layer then fine-tune!")

    num_train_steps = num_actual_train_examples // batch_size * epochs
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    cur_layer = FLAGS.cur_layer
    most_important_head = FLAGS.most_important_head
    print("Current most important head:", cur_layer, most_important_head)

    # These placeholders control the dropout and training flags.
    keep_prob_ph = tf.compat.v1.placeholder(tf.float32, name="keep_prob")
    is_training_ph = tf.compat.v1.placeholder(tf.bool, name='is_training')

    # Two placeholders for the head coordinates: layer, head.
    head_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                            name='head_mask')
    layer_mask_ph = tf.compat.v1.placeholder(tf.int32, shape=[None],
                                             name='layer_mask')

    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training_ph,
        input_ids=input_ids_ph,  # input_ids
        input_mask=input_mask_ph,  # input_mask
        token_type_ids=segment_ids_ph,  # segment_ids
        use_one_hot_embeddings=False,
        use_estimator=False,
        head_mask=head_mask_ph,
        layer_mask=layer_mask_ph)

    output_layer = model.get_pooled_output()
    output_layer = tf.nn.dropout(output_layer, keep_prob=keep_prob_ph)

    output_weights = tf.get_variable(
        "output_weights", [num_labels, embed_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    with tf.compat.v1.variable_scope("loss"):
        # for stsb (regression), num_labels == 1
        if num_labels == 1:
            logits = tf.squeeze(logits, [-1])
            per_example_loss = tf.square(logits - label_ids_ph)
            loss = tf.reduce_mean(per_example_loss)
        else:
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids_ph, depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)

    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

    # Metric and summary; each metric entry is a tf.metrics pair (val, op).
    metric = metric_fn(per_example_loss, label_ids_ph, logits, num_labels,
                       task_name)
    metric_name = list(metric.keys())
    metric_val = [m[0] for m in metric.values()]
    metric_op = [m[1] for m in metric.values()]

    log_dir = FLAGS.output_dir + 'layer_{}_head_{}/'.format(
        cur_layer, most_important_head)

    metric_phs = [
        tf.compat.v1.placeholder(tf.float32, name="{}_ph".format(key))
        for key in metric.keys()
    ]
    summaries = [
        tf.compat.v1.summary.scalar(key, metric_phs[i])
        for i, key in enumerate(metric.keys())
    ]
    train_summary_total = tf.summary.merge(summaries)
    eval_summary_total = tf.summary.merge(summaries)

    init_checkpoint = FLAGS.init_checkpoint
    tvars = tf.compat.v1.trainable_variables()
    var_init = [
        v for v in tvars
        if 'output_weights' not in v.name and 'output_bias' not in v.name
    ]
    var_output = [
        v for v in tvars
        if 'output_weights' in v.name or "output_bias" in v.name
    ]
    if not FLAGS.load_from_finetuned:
        # Init from Model0.
        saver_init = tf.train.Saver(var_init)
    else:
        # Init from Model1.
        saver_init = tf.train.Saver(var_init + var_output)

    var_train = var_init + var_output
    print("Training parameters")
    for v in var_train:
        print(v)

    train_op = optimization.create_optimizer(loss=loss,
                                             init_lr=FLAGS.learning_rate,
                                             num_train_steps=num_train_steps,
                                             num_warmup_steps=num_warmup_steps,
                                             use_tpu=False,
                                             tvars=var_train)
    saver_all = tf.train.Saver(var_list=var_init + var_output, max_to_keep=1)

    # Isolate the variables stored behind the scenes by the metric operation.
    var_metric = []
    for key in metric.keys():
        var_metric.extend(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=key))
    # Define initializer to initialize/reset running variables.
    metric_vars_initializer = tf.variables_initializer(var_list=var_metric)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver_init.restore(sess, init_checkpoint)

        writer = tf.compat.v1.summary.FileWriter(log_dir + 'log/train/',
                                                 sess.graph)
        writer_eval = tf.compat.v1.summary.FileWriter(log_dir + 'log/eval/')

        # Heads that need to be masked out in cur_layer: every head except the
        # most important one.
        head_mask = [head for head in range(12) if head != most_important_head]
        layer_mask = [cur_layer for _ in range(11)]

        # If the number of eval examples is <= 1000, load it in one shot;
        # otherwise load it batch by batch where needed.
        if num_actual_eval_examples <= 1000:
            eval_input_ids, eval_input_mask, eval_segment_ids, \
                eval_label_ids, eval_is_real_example = generate_ph_input(
                    batch_size=num_actual_eval_examples,
                    seq_length=seq_length,
                    examples=eval_examples,
                    label_list=label_list,
                    tokenizer=tokenizer)

        start_metric = {"eval_{}".format(key): 0 for key in metric_name}
        end_metric = {"eval_{}".format(key): 0 for key in metric_name}

        if FLAGS.do_train:
            tf.logging.info("***** Run training *****")
            step = 1
            for n in range(epochs):
                np.random.shuffle(train_examples)
                num_batch = num_actual_train_examples // batch_size \
                    if num_actual_train_examples % batch_size == 0 \
                    else num_actual_train_examples // batch_size + 1
                id = 0
                for b in range(num_batch):
                    input_ids, input_mask, segment_ids, label_ids, \
                        is_real_example = generate_ph_input(
                            batch_size=batch_size,
                            seq_length=seq_length,
                            examples=train_examples,
                            label_list=label_list,
                            tokenizer=tokenizer,
                            train_idx_offset=id)
                    id += batch_size

                    sess.run(metric_vars_initializer)
                    sess.run([train_op] + metric_op, feed_dict={
                        input_ids_ph: input_ids,
                        input_mask_ph: input_mask,
                        segment_ids_ph: segment_ids,
                        label_ids_ph: label_ids,
                        is_training_ph: True,
                        keep_prob_ph: 0.9,
                        head_mask_ph: head_mask,
                        layer_mask_ph: layer_mask})
                    train_metric_val = sess.run(metric_val)
                    train_summary_str = sess.run(
                        train_summary_total,
                        feed_dict={ph: value for ph, value in
                                   zip(metric_phs, train_metric_val)})
                    writer.add_summary(train_summary_str, step)

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        # Evaluate on the dev set.
                        if num_actual_eval_examples <= 1000:
                            sess.run(metric_vars_initializer)
                            sess.run(metric_op, feed_dict={
                                input_ids_ph: eval_input_ids,
                                input_mask_ph: eval_input_mask,
                                segment_ids_ph: eval_segment_ids,
                                label_ids_ph: eval_label_ids,
                                is_training_ph: False,
                                keep_prob_ph: 1,
                                head_mask_ph: head_mask,
                                layer_mask_ph: layer_mask})
                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={ph: value for ph, value in
                                           zip(metric_phs, eval_metric_val)})
                        else:
                            num_batch_eval = num_actual_eval_examples // batch_size \
                                if num_actual_eval_examples % batch_size == 0 \
                                else num_actual_eval_examples // batch_size + 1
                            id_eval = 0
                            sess.run(metric_vars_initializer)
                            for _ in range(num_batch_eval):
                                eval_input_ids, eval_input_mask, eval_segment_ids, \
                                    eval_label_ids, eval_is_real_example = generate_ph_input(
                                        batch_size=batch_size,
                                        seq_length=seq_length,
                                        examples=eval_examples,
                                        label_list=label_list,
                                        tokenizer=tokenizer,
                                        train_idx_offset=id_eval)
                                id_eval += batch_size
                                sess.run(metric_op, feed_dict={
                                    input_ids_ph: eval_input_ids,
                                    input_mask_ph: eval_input_mask,
                                    segment_ids_ph: eval_segment_ids,
                                    label_ids_ph: eval_label_ids,
                                    is_training_ph: False,
                                    keep_prob_ph: 1,
                                    head_mask_ph: head_mask,
                                    layer_mask_ph: layer_mask})
                            eval_metric_val = sess.run(metric_val)
                            eval_summary_str = sess.run(
                                eval_summary_total,
                                feed_dict={ph: value for ph, value in
                                           zip(metric_phs, eval_metric_val)})
                        writer_eval.add_summary(eval_summary_str, step)

                        if step == 1:
                            for key, val in zip(metric_name, eval_metric_val):
                                start_metric["eval_{}".format(key)] = val
                        if step == epochs * num_batch:
                            for key, val in zip(metric_name, eval_metric_val):
                                end_metric["eval_{}".format(key)] = val

                    if step % 100 == 0 or step % num_batch == 0 or step == 1:
                        train_metric_list = []
                        for i in range(len(train_metric_val)):
                            if metric_name[i] == 'loss':
                                train_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                            else:
                                train_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    train_metric_val[i])
                        train_str = 'Train ' + '|'.join(train_metric_list)

                        eval_metric_list = []
                        for i in range(len(eval_metric_val)):
                            if metric_name[i] == 'loss':
                                eval_metric_list.append(
                                    "{}: %2.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                            else:
                                eval_metric_list.append(
                                    "{}: %.4f".format(metric_name[i]) %
                                    eval_metric_val[i])
                        eval_str = 'Eval ' + '|'.join(eval_metric_list)

                        print(
                            "Layer {}, leave only one head {} | Epoch: %4d/%4d | Batch: %4d/%4d | {} | {}"
                            .format(cur_layer, most_important_head, train_str,
                                    eval_str) % (n, epochs, b, num_batch))

                    if step % num_batch == 0:
                        saver_all.save(sess,
                                       log_dir + 'layer{}_head_{}'.format(
                                           cur_layer, most_important_head),
                                       global_step=step)
                    step += 1

        writer.close()
        writer_eval.close()

        print("Start metric", start_metric)
        print("End metric", end_metric)

        with tf.io.gfile.GFile(FLAGS.output_dir + 'results.txt', 'a') as writer:
            eval_start, eval_end = [], []
            for metric in metric_name:
                if metric != 'loss':
                    eval_start.append("{}: %.4f".format(metric) %
                                      start_metric["eval_{}".format(metric)])
                    eval_end.append("{}: %.4f".format(metric) %
                                    end_metric["eval_{}".format(metric)])
            writer.write(
                "Layer {}, leave only one head {}: Start: {} | End: {}\n".format(
                    cur_layer, most_important_head, ','.join(eval_start),
                    ','.join(eval_end)))
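
# Illustrative sketch of the mask construction used above when only the most important
# head of one layer is kept. `leave_one_head_masks` is a hypothetical helper, not part
# of the original script: for the chosen layer, every head except the given one is
# masked, so only that head stays active in that layer.
def leave_one_head_masks(cur_layer, most_important_head, n_heads=12):
    head_mask = [head for head in range(n_heads) if head != most_important_head]
    layer_mask = [cur_layer for _ in range(n_heads - 1)]
    return head_mask, layer_mask

# Example: leave_one_head_masks(5, 9) masks heads [0, 1, ..., 8, 10, 11] of layer 5,
# leaving only head 9 of that layer unmasked.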