def get_classification_loss(FLAGS, features, n_class, is_training):
    """Build the downstream classification loss graph on top of XLNet (TF).

    Args:
        FLAGS: flag object carrying run options (model_config_path,
            summary_type, use_summ_proj, cls_scope, task_name, ...).
        features: dict of tensors keyed "input_ids", "segment_ids",
            "input_mask", "label_ids"; the first three are assumed to be
            [batch, seq_len] -- TODO confirm against the input pipeline.
        n_class: number of target classes.
        is_training: enables training-mode behavior in the run config.

    Returns:
        Tuple (total_loss, per_example_loss, logits, summary).
    """
    import xlnet
    import modeling

    batch_size = tf.shape(features["input_ids"])[0]

    # XLNet consumes time-major [seq_len, batch] tensors.
    input_ids = tf.transpose(features["input_ids"], [1, 0])
    segment_ids = tf.transpose(features["segment_ids"], [1, 0])
    input_mask = tf.transpose(features["input_mask"], [1, 0])
    labels = tf.reshape(features["label_ids"], [batch_size])

    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(is_training, True, FLAGS)

    xlnet_model = xlnet.XLNetModel(
        xlnet_config=xlnet_config,
        run_config=run_config,
        input_ids=input_ids,
        seg_ids=segment_ids,
        input_mask=input_mask)

    # Pooled sequence representation fed to the classification head.
    summary = xlnet_model.get_pooled_out(FLAGS.summary_type,
                                         FLAGS.use_summ_proj)

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        # An explicit cls_scope flag overrides the task-name-derived scope.
        if FLAGS.cls_scope:
            cls_scope = "classification_{}".format(FLAGS.cls_scope)
        else:
            cls_scope = "classification_{}".format(FLAGS.task_name.lower())

        per_example_loss, logits = modeling.classification_loss(
            hidden=summary,
            labels=labels,
            n_class=n_class,
            initializer=xlnet_model.get_initializer(),
            scope=cls_scope,
            return_logits=True)

        total_loss = tf.reduce_mean(per_example_loss)

    return total_loss, per_example_loss, logits, summary
def get_classification_loss(args, xlnet_config, features, n_class, is_training=True):
    """Build the downstream classification loss on top of XLNet (Paddle fluid).

    NOTE(review): this redefines ``get_classification_loss`` and, at module
    level, shadows the TensorFlow variant defined earlier in this file --
    confirm that only one framework path is intended to be active.

    Args:
        args: argument object carrying run options (summary_type,
            use_summ_proj, task_name, ...).
        xlnet_config: parsed XLNet model configuration.
        features: dict of tensors keyed "input_ids", "segment_ids",
            "input_mask", "label_ids"; input_ids is transposed with a 3-D
            perm, so it is assumed rank 3 -- TODO confirm expected shape.
        n_class: number of target classes.
        is_training: training-mode flag (default True).

    Returns:
        Tuple (total_loss, per_example_loss, logits).
    """
    # XLNet consumes time-major inputs; keep batch as the second axis.
    input_ids = fluid.layers.transpose(features["input_ids"], [1, 0, 2])
    segment_ids = features["segment_ids"]
    input_mask = fluid.layers.transpose(features["input_mask"], [1, 0])
    labels = features["label_ids"]

    xlnet_model = XLNetModel(
        input_ids=input_ids,
        seg_ids=segment_ids,
        input_mask=input_mask,
        xlnet_config=xlnet_config,
        args=args)

    # Pooled sequence representation fed to the classification head.
    summary = xlnet_model.get_pooled_out(args.summary_type, args.use_summ_proj)

    per_example_loss, logits = modeling.classification_loss(
        hidden=summary,
        labels=labels,
        n_class=n_class,
        initializer=xlnet_model.get_initializer(),
        name="model_classification_{}".format(args.task_name),
        return_logits=True)

    total_loss = fluid.layers.reduce_mean(per_example_loss)

    return total_loss, per_example_loss, logits
def __init__(self, FLAGS=FLAGS, n_class=2, is_training=False):
    """Build the XLNet classification inference graph and restore weights.

    Args:
        FLAGS: flag object with model paths and run options. NOTE(review):
            the default is the module-level FLAGS captured when the class
            body is executed, not at call time.
        n_class: number of output classes (default: binary).
        is_training: build the graph in training mode if True.
    """
    import xlnet
    import modeling
    import tensorflow as tf

    init_log()
    logging.info("Init semantic model ...")

    # SentencePiece tokenizer used to encode raw text for this model.
    self.sp = spm.SentencePieceProcessor()
    self.sp.Load(FLAGS.spiece_model_file)

    tf.logging.set_verbosity(tf.logging.INFO)

    # Mask placeholder dtype follows the bfloat16 switch.
    tf_float = tf.bfloat16 if FLAGS.use_bfloat16 else tf.float32

    # Feed placeholders: [batch, seq_len] ids/masks and [batch] labels.
    self.input_ids = tf.placeholder(
        dtype=tf.int64, shape=[None, None], name="input_ids")
    self.segment_ids = tf.placeholder(
        dtype=tf.int32, shape=[None, None], name="segment_ids")
    self.input_mask = tf.placeholder(
        dtype=tf_float, shape=[None, None], name="input_mask")
    self.label_ids = tf.placeholder(
        dtype=tf.int64, shape=[None], name="label_ids")

    batch_size = tf.shape(self.input_ids)[0]

    # XLNet consumes time-major [seq_len, batch] tensors.
    input_ids = tf.transpose(self.input_ids, [1, 0])
    segment_ids = tf.transpose(self.segment_ids, [1, 0])
    input_mask = tf.transpose(self.input_mask, [1, 0])
    labels = tf.reshape(self.label_ids, [batch_size])

    self.sess = tf.Session()

    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(is_training, True, FLAGS)

    xlnet_model = xlnet.XLNetModel(
        xlnet_config=xlnet_config,
        run_config=run_config,
        input_ids=input_ids,
        seg_ids=segment_ids,
        input_mask=input_mask)

    # Pooled sequence representation; kept on self for inference use.
    self.summary = xlnet_model.get_pooled_out(FLAGS.summary_type,
                                              FLAGS.use_summ_proj)

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        # An explicit cls_scope flag overrides the task-name-derived scope.
        if FLAGS.cls_scope:
            cls_scope = "classification_{}".format(FLAGS.cls_scope)
        else:
            cls_scope = "classification_{}".format(FLAGS.task_name.lower())

        per_example_loss, logits = modeling.classification_loss(
            hidden=self.summary,
            labels=labels,
            n_class=n_class,
            initializer=xlnet_model.get_initializer(),
            scope=cls_scope,
            return_logits=True)

        total_loss = tf.reduce_mean(per_example_loss)

    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('#params: {}'.format(num_params))

    # Restore pretrained weights into the live session.
    # NOTE(review): assumes XLNetModel exposes a `saver` attribute -- confirm.
    xlnet_model.saver.restore(self.sess, FLAGS.init_checkpoint)

    #### load pretrained models
    #scaffold_fn =
    # NOTE(review): the assignment above is commented out but the call below
    # still executes, overlapping with the saver restore -- confirm intent.
    model_utils.init_from_checkpoint(FLAGS)

    logging.info("Init semantic model finished ...")