def two_stream_loss(FLAGS, features, labels, mems, is_training): """Pretraining loss with two-stream attention Transformer-XL.""" #### Unpack input mem_name = "mems" mems = mems.get(mem_name, None) inp_k = tf.transpose(features["input_k"], [1, 0]) inp_q = tf.transpose(features["input_q"], [1, 0]) seg_id = tf.transpose(features["seg_id"], [1, 0]) inp_mask = None perm_mask = tf.transpose(features["perm_mask"], [1, 2, 0]) if FLAGS.num_predict is not None: # [num_predict x tgt_len x bsz] target_mapping = tf.transpose(features["target_mapping"], [1, 2, 0]) else: target_mapping = None # target for LM loss tgt = tf.transpose(features["target"], [1, 0]) # target mask for LM loss tgt_mask = tf.transpose(features["target_mask"], [1, 0]) # construct xlnet config and save to model_dir xlnet_config = xlnet.XLNetConfig(FLAGS=FLAGS) xlnet_config.to_json(os.path.join(FLAGS.model_dir, "config.json")) # construct run config from FLAGS run_config = xlnet.create_run_config(is_training, False, FLAGS) xlnet_model = xlnet.XLNetModel(xlnet_config=xlnet_config, run_config=run_config, input_ids=inp_k, seg_ids=seg_id, input_mask=inp_mask, mems=mems, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q) output = xlnet_model.get_sequence_output() new_mems = {mem_name: xlnet_model.get_new_memory()} lookup_table = xlnet_model.get_embedding_table() initializer = xlnet_model.get_initializer() with tf.variable_scope("model", reuse=tf.AUTO_REUSE): # LM loss lm_loss = modeling.lm_loss(hidden=output, target=tgt, n_token=xlnet_config.n_token, d_model=xlnet_config.d_model, initializer=initializer, lookup_table=lookup_table, tie_weight=True, bi_data=run_config.bi_data, use_tpu=run_config.use_tpu) #### Quantity to monitor monitor_dict = {} if FLAGS.use_bfloat16: tgt_mask = tf.cast(tgt_mask, tf.float32) lm_loss = tf.cast(lm_loss, tf.float32) total_loss = tf.reduce_sum(lm_loss * tgt_mask) / tf.reduce_sum(tgt_mask) monitor_dict["total_loss"] = total_loss return total_loss, new_mems, monitor_dict
def __init__(self, model_path): config_file = os.path.join(model_path, 'xlnet_config.json') spiece_file = os.path.join(model_path, 'spiece.model') ckpt_path = os.path.join(model_path, 'xlnet_model.ckpt') self.predict_batch_size = 100 self.max_seq_length = 256 self.max_predictions_per_seq = 20 is_training = False # construct xlnet config and save to model_dir xlnet_config = xlnet.XLNetConfig(json_path=config_file) # construct run config from FLAGS # self.run_config = xlnet.create_run_config(is_training, False, FLAGS) run_config = xlnet.RunConfig(is_training, False, False, 0.0, 0.0, init="normal", init_range=0.1, init_std=0.02, mem_len=None, reuse_len=256, bi_data=False, clamp_len=-1, same_length=False) graph = tf.Graph() with graph.as_default(): self.session = tf.Session() # shape: max_sentence_length x num_sentence self.input_ids = tf.placeholder(tf.int32, [None, None], name="input_ids") # shape: max_sentence_length x num_sentence self.seg_ids = tf.placeholder(tf.int32, [None, None], name="seg_ids") # shape: max_sentence_length x num_sentence self.input_mask = tf.placeholder(tf.int32, [None, None], name="input_mask") # shape: max_sentence_length x max_sentence_length x num_sentence self.perm_mask = tf.placeholder(tf.int32, [None, None, None], name="perm_mask") # shape: max_predictions_per_seq x max_sentence_length x num_sentence self.target_mapping = tf.placeholder(tf.int32, [None, None, None], name="target_mapping") # shape: max_sentence_length x num_sentence self.inp_q = tf.placeholder(tf.int32, [None, None], name="inp_q") # shape: max_sentence_length x num_sentence self.target = tf.placeholder(tf.int32, [None, None], name="target") # # shape: bool scaler # self.is_training = tf.placeholder(tf.bool, name="is_training") xlnet_model = xlnet.XLNetModel(xlnet_config=xlnet_config, run_config=run_config, input_ids=self.input_ids, seg_ids=self.seg_ids, input_mask=self.input_mask, mems=None, perm_mask=self.perm_mask, target_mapping=self.target_mapping, inp_q=self.inp_q) output = xlnet_model.get_sequence_output() lookup_table = xlnet_model.get_embedding_table() initializer = xlnet_model.get_initializer() with tf.variable_scope("model", reuse=tf.AUTO_REUSE): # LM loss lm_loss, logits = modeling.lm_loss( hidden=output, target=self.target, n_token=xlnet_config.n_token, d_model=xlnet_config.d_model, initializer=initializer, lookup_table=lookup_table, tie_weight=True, bi_data=run_config.bi_data, use_tpu=run_config.use_tpu) self.masked_lm_example_loss = run_lm_predict.get_masked_lm_output( self.bert_config, model.get_sequence_output(), model.get_embedding_table(), self.masked_lm_positions, self.masked_lm_ids) # load the pretrained bert model parameters (assignment_map, initialized_variable_names ) = modeling.get_assignment_map_from_checkpoint( tf.trainable_variables(), bert_ckpt) tf.train.init_from_checkpoint(bert_ckpt, assignment_map) #### load pretrained models scaffold_fn = model_utils.init_from_checkpoint(FLAGS) self.session.run(tf.global_variables_initializer()) self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)