import tensorflow as tf

# modeling, hidden_to_logits, get_tsa_threshold, kl_for_log_probs and FLAGS
# are defined elsewhere in the surrounding UDA codebase.


# Sentence-level UDA model: pooled BERT output, one label per example.
# The batch is laid out as [supervised | unsup original | unsup augmented].
def create_model(
    bert_config,
    is_training,
    input_ids,
    input_mask,
    input_type_ids,
    labels,
    num_labels,
    use_one_hot_embeddings,
    tsa,
    unsup_ratio,
    global_step,
    num_train_steps,
):
  num_sample = input_ids.shape[0].value
  if is_training:
    assert num_sample % (1 + 2 * unsup_ratio) == 0
    sup_batch_size = num_sample // (1 + 2 * unsup_ratio)
    unsup_batch_size = sup_batch_size * unsup_ratio
  else:
    sup_batch_size = num_sample
    unsup_batch_size = 0

  pooled = modeling.bert_model(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=input_type_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  clas_logits = hidden_to_logits(
      hidden=pooled,
      is_training=is_training,
      num_classes=num_labels,
      scope="classifier")

  log_probs = tf.nn.log_softmax(clas_logits, axis=-1)

  correct_label_probs = None

  with tf.variable_scope("sup_loss"):
    sup_log_probs = log_probs[:sup_batch_size]
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
    tgt_label_prob = one_hot_labels

    per_example_loss = -tf.reduce_sum(tgt_label_prob * sup_log_probs, axis=-1)
    loss_mask = tf.ones_like(per_example_loss, dtype=per_example_loss.dtype)

    correct_label_probs = tf.reduce_sum(
        one_hot_labels * tf.exp(sup_log_probs), axis=-1)

    if tsa:
      # Training Signal Annealing: mask out examples the model already
      # predicts correctly with probability above the (growing) threshold.
      tsa_start = 1. / num_labels
      tsa_threshold = get_tsa_threshold(
          tsa, global_step, num_train_steps, tsa_start, end=1)
      larger_than_threshold = tf.greater(correct_label_probs, tsa_threshold)
      loss_mask = loss_mask * (1 - tf.cast(larger_than_threshold, tf.float32))
    else:
      tsa_threshold = 1

    loss_mask = tf.stop_gradient(loss_mask)
    per_example_loss = per_example_loss * loss_mask
    sup_loss = (tf.reduce_sum(per_example_loss) /
                tf.maximum(tf.reduce_sum(loss_mask), 1))

  unsup_loss_mask = None
  if is_training and unsup_ratio > 0:
    with tf.variable_scope("unsup_loss"):
      ori_start = sup_batch_size
      ori_end = ori_start + unsup_batch_size
      aug_start = sup_batch_size + unsup_batch_size
      aug_end = aug_start + unsup_batch_size

      ori_log_probs = log_probs[ori_start:ori_end]
      aug_log_probs = log_probs[aug_start:aug_end]
      unsup_loss_mask = 1

      if FLAGS.uda_softmax_temp != -1:
        # Sharpen the target distribution with a softmax temperature.
        tgt_ori_log_probs = tf.nn.log_softmax(
            clas_logits[ori_start:ori_end] / FLAGS.uda_softmax_temp,
            axis=-1)
        tgt_ori_log_probs = tf.stop_gradient(tgt_ori_log_probs)
      else:
        tgt_ori_log_probs = tf.stop_gradient(ori_log_probs)

      if FLAGS.uda_confidence_thresh != -1:
        # Only keep examples whose highest predicted probability exceeds
        # the confidence threshold.
        largest_prob = tf.reduce_max(tf.exp(ori_log_probs), axis=-1)
        unsup_loss_mask = tf.cast(
            tf.greater(largest_prob, FLAGS.uda_confidence_thresh), tf.float32)
        unsup_loss_mask = tf.stop_gradient(unsup_loss_mask)

      # Consistency loss: KL(original-as-target || augmented).
      per_example_kl_loss = kl_for_log_probs(
          tgt_ori_log_probs, aug_log_probs) * unsup_loss_mask
      unsup_loss = tf.reduce_mean(per_example_kl_loss)
  else:
    unsup_loss = 0.

  return (sup_loss, unsup_loss, clas_logits[:sup_batch_size],
          per_example_loss, loss_mask, tsa_threshold, unsup_loss_mask,
          correct_label_probs)
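
# For reference, minimal sketches of the two helpers used above, modeled on
# the reference UDA implementation (an assumption: the surrounding codebase
# may differ in details such as the scale constant or schedule names).

def get_tsa_threshold(schedule, global_step, num_train_steps, start, end):
  # The threshold ramps from `start` (1/num_labels) to `end` (1) over
  # training, releasing supervised signal gradually.
  training_progress = tf.to_float(global_step) / tf.to_float(num_train_steps)
  if schedule == "linear_schedule":
    coeff = training_progress
  elif schedule == "exp_schedule":
    scale = 5
    coeff = tf.exp((training_progress - 1) * scale)   # slow start, fast end
  elif schedule == "log_schedule":
    scale = 5
    coeff = 1 - tf.exp((-training_progress) * scale)  # fast start, slow end
  return coeff * (end - start) + start


def kl_for_log_probs(log_p, log_q):
  # KL(p || q) computed from log-probabilities along the class axis.
  p = tf.exp(log_p)
  neg_ent = tf.reduce_sum(p * log_p, axis=-1)
  neg_cross_ent = tf.reduce_sum(p * log_q, axis=-1)
  return neg_ent - neg_cross_ent
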
# Token-level variant: per-token logits over the full sequence. Before the
# consistency loss, the augmented batch's token positions are re-matched to
# the original order (dropping the inserted B-TRI / E-TRI tags).
def create_model(
    bert_config,
    is_training,
    input_ids,
    input_mask,
    input_type_ids,
    labels,
    num_labels,
    use_one_hot_embeddings,
    tsa,
    unsup_ratio,
    global_step,
    num_train_steps,
):
  num_sample = input_ids.shape[0].value
  if is_training:
    assert num_sample % (1 + 2 * unsup_ratio) == 0
    sup_batch_size = num_sample // (1 + 2 * unsup_ratio)
    unsup_batch_size = sup_batch_size * unsup_ratio
  else:
    sup_batch_size = num_sample
    unsup_batch_size = 0

  sequence = modeling.bert_model(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=input_type_ids,
      use_one_hot_embeddings=use_one_hot_embeddings,
      output_type="sequence")

  clas_logits = hidden_to_logits(
      hidden=sequence,
      is_training=is_training,
      num_classes=num_labels,
      scope="classifier")

  log_probs = tf.nn.log_softmax(clas_logits, axis=-1)

  correct_label_probs = None

  with tf.variable_scope("sup_loss"):
    sup_log_probs = log_probs[:sup_batch_size]
    tf.logging.info("sup_batch_size = %d", sup_batch_size)
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
    tgt_label_prob = one_hot_labels

    per_example_loss = -tf.reduce_sum(tgt_label_prob * sup_log_probs, axis=-1)
    loss_mask = tf.ones_like(per_example_loss, dtype=per_example_loss.dtype)

    correct_label_probs = tf.reduce_sum(
        one_hot_labels * tf.exp(sup_log_probs), axis=-1)

    if tsa:
      tsa_start = 1. / num_labels
      tsa_threshold = get_tsa_threshold(
          tsa, global_step, num_train_steps, tsa_start, end=1)
      larger_than_threshold = tf.greater(correct_label_probs, tsa_threshold)
      loss_mask = loss_mask * (1 - tf.cast(larger_than_threshold, tf.float32))
    else:
      tsa_threshold = 1

    loss_mask = tf.stop_gradient(loss_mask)
    per_example_loss = per_example_loss * loss_mask
    sup_loss = (tf.reduce_sum(per_example_loss) /
                tf.maximum(tf.reduce_sum(loss_mask), 1))

  unsup_loss_mask = None
  # Guard on is_training as well: the unsupervised slices are empty at eval
  # time (unsup_batch_size == 0).
  if is_training and unsup_ratio > 0:
    with tf.variable_scope("unsup_loss"):
      ori_start = sup_batch_size
      ori_end = ori_start + unsup_batch_size
      aug_start = sup_batch_size + unsup_batch_size
      aug_end = aug_start + unsup_batch_size

      ori_log_probs = log_probs[ori_start:ori_end]
      aug_log_probs_before = log_probs[aug_start:aug_end]

      # Kick out the B-TRI and E-TRI tags and re-match the augmented tokens
      # to the original token order: input_type_ids marks the positions to
      # move to the front (type 1) ahead of the rest (type 0).
      # Note: the last logits dimension is the number of labels.
      _, max_seq_length, hidden_dim = clas_logits.get_shape().as_list()
      aug_input_type_ids = input_type_ids[aug_start:aug_end]
      one = tf.where(tf.equal(aug_input_type_ids, 1))
      zero = tf.where(tf.equal(aug_input_type_ids, 0))
      p = tf.concat([one, zero], 0)

      def my_numpy_func(x):
        # Stable sort by batch index so that, within each example, type-1
        # positions stay ahead of type-0 positions in their original order
        # (the default unstable argsort could scramble them).
        return x[x[:, 0].argsort(kind="stable")]

      aug_input_type_ids_trans = tf.py_func(my_numpy_func, [p], Tout=tf.int64)
      aug_log_probs_middle = tf.gather_nd(
          params=aug_log_probs_before, indices=aug_input_type_ids_trans)
      aug_log_probs = tf.reshape(
          aug_log_probs_middle, [-1, max_seq_length, hidden_dim])

      # Alternative (commented out in the original): re-match via flat index
      # arithmetic and tf.gather instead of tf.py_func.
      # vec = tf.to_int32(
      #     tf.expand_dims(np.arange(unsup_batch_size) * max_seq_length, -1))
      # trans = tf.tile(vec, [1, max_seq_length])
      # aug_input_type_ids_trans = tf.reshape(aug_input_type_ids + trans, [-1])
      # aug_input_type_ids_trans = tf.stop_gradient(aug_input_type_ids_trans)
      # aug_log_probs_middle = tf.reshape(aug_log_probs_before, [-1, hidden_dim])
      # aug_log_probs = tf.reshape(
      #     tf.gather(params=aug_log_probs_middle,
      #               indices=aug_input_type_ids_trans),
      #     [-1, max_seq_length, hidden_dim])

      unsup_loss_mask = 1
      if FLAGS.uda_softmax_temp != -1:
        # Re-normalizing temperature-scaled log-probs is equivalent to
        # temperature-scaling the logits (log_softmax is shift-invariant
        # per row).
        m_aug_log_probs = tf.nn.log_softmax(
            aug_log_probs / FLAGS.uda_softmax_temp, axis=-1)
        tgt_aug_log_probs = tf.stop_gradient(m_aug_log_probs)
      else:
        tgt_aug_log_probs = tf.stop_gradient(aug_log_probs)

      if FLAGS.uda_confidence_thresh != -1:
        largest_prob = tf.reduce_max(tf.exp(ori_log_probs), axis=-1)
        unsup_loss_mask = tf.cast(
            tf.greater(largest_prob, FLAGS.uda_confidence_thresh), tf.float32)
        unsup_loss_mask = tf.stop_gradient(unsup_loss_mask)

      # Unlike the sentence-level version, the KL target here is the
      # (re-matched, sharpened) augmented branch; the original branch
      # receives the gradient.
      per_example_kl_loss = kl_for_log_probs(
          tgt_aug_log_probs, ori_log_probs) * unsup_loss_mask
      unsup_loss = tf.reduce_mean(per_example_kl_loss)
  else:
    # No unsupervised batch: return placeholders so the output signature
    # stays fixed.
    unsup_loss = sup_loss
    tgt_aug_log_probs = log_probs
    ori_log_probs = log_probs

  return (sup_loss, unsup_loss, clas_logits[:sup_batch_size],
          per_example_loss, loss_mask, tsa_threshold, unsup_loss_mask,
          correct_label_probs, tgt_aug_log_probs, ori_log_probs)
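
# A self-contained sketch (numpy only, hypothetical shapes) of the
# re-matching step above: per example, positions with type id 1 move to the
# front, followed by the type-0 positions, each group keeping its original
# left-to-right order. The stable sort on the batch index is what preserves
# that within-group order.

import numpy as np

aug_input_type_ids = np.array([[0, 1, 1, 0],    # example 0
                               [1, 0, 0, 1]])   # example 1
one = np.argwhere(aug_input_type_ids == 1)      # (batch, pos) pairs, type 1
zero = np.argwhere(aug_input_type_ids == 0)     # (batch, pos) pairs, type 0
p = np.concatenate([one, zero], axis=0)
indices = p[p[:, 0].argsort(kind="stable")]     # group by example, 1s first
print(indices)
# [[0 1]
#  [0 2]
#  [0 0]
#  [0 3]
#  [1 0]
#  [1 3]
#  [1 1]
#  [1 2]]
# Feeding these (batch, pos) pairs to tf.gather_nd reorders each example's
# token log-probs accordingly before the reshape back to
# [unsup_batch_size, max_seq_length, num_labels].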