def compute_loss(log_probs, positions, depth=seq_length):
    # NOTE: `seq_length` must already be defined in the enclosing scope for
    # this default argument to resolve at import time.
    one_hot_positions = tf.one_hot(positions, depth=depth, dtype=tf.float32)
    loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
    loss = tf.reduce_mean(loss)
    return loss
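
# Illustrative sketch (not from the original code): `compute_loss` is the mean
# negative log-likelihood of the gold position under `log_probs`; when fed
# log-softmax outputs it matches sparse softmax cross-entropy on the raw logits.
def _example_compute_loss(logits, positions, depth):
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    loss_a = compute_loss(log_probs, positions, depth=depth)
    loss_b = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=positions, logits=logits))
    return loss_a, loss_b  # equal up to floating-point error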
def classification_loss(hidden, labels, n_class, initializer, scope,
                        reuse=None, return_logits=False):
    """Different classification tasks should use different scope names to
    ensure that different dense layers (parameters) are used to produce the
    logits. An exception is transfer learning, where one hopes to transfer
    the classification weights.
    """
    with tf.variable_scope(scope, reuse=reuse):
        logits = tf.layers.dense(hidden, n_class,
                                 kernel_initializer=initializer, name='logit')
        one_hot_target = tf.one_hot(labels, n_class, dtype=hidden.dtype)
        loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
        if return_logits:
            return loss, logits
        return loss
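
# Illustrative sketch (scope names and class counts are assumptions): two
# classification heads sharing one encoder output but living in separate
# variable scopes, so each task gets its own dense-layer parameters.
def _example_two_task_losses(summary, labels_a, labels_b, initializer):
    loss_a = classification_loss(summary, labels_a, n_class=3,
                                 initializer=initializer, scope="task_a")
    loss_b = classification_loss(summary, labels_b, n_class=2,
                                 initializer=initializer, scope="task_b")
    return tf.reduce_mean(loss_a) + tf.reduce_mean(loss_b)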
def get_race_loss(FLAGS, features, is_training):
    """Loss for downstream multi-choice QA tasks such as RACE."""
    bsz_per_core = tf.shape(features["input_ids"])[0]

    def _transform_features(feature):
        out = tf.reshape(feature, [bsz_per_core, 4, -1])
        out = tf.transpose(out, [2, 0, 1])
        out = tf.reshape(out, [-1, bsz_per_core * 4])
        return out

    inp = _transform_features(features["input_ids"])
    seg_id = _transform_features(features["segment_ids"])
    inp_mask = _transform_features(features["input_mask"])
    label = tf.reshape(features["label_ids"], [bsz_per_core])

    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(is_training, True, FLAGS)

    xlnet_model = xlnet.XLNetModel(
        xlnet_config=xlnet_config,
        run_config=run_config,
        input_ids=inp,
        seg_ids=seg_id,
        input_mask=inp_mask)
    summary = xlnet_model.get_pooled_out(FLAGS.summary_type,
                                         FLAGS.use_summ_proj)

    with tf.variable_scope("logits"):
        logits = tf.layers.dense(
            summary, 1, kernel_initializer=xlnet_model.get_initializer())
        logits = tf.reshape(logits, [bsz_per_core, 4])

        one_hot_target = tf.one_hot(label, 4)
        per_example_loss = -tf.reduce_sum(
            tf.nn.log_softmax(logits) * one_hot_target, -1)
        total_loss = tf.reduce_mean(per_example_loss)

    return total_loss, per_example_loss, logits
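
# Shape sketch (illustrative, with made-up sizes): bsz_per_core=2, 4 answer
# choices, per-choice length 5. `_transform_features` maps a [bsz, 4*len]
# feature to [len, bsz*4], so the sequence dimension comes first and the
# choices are folded into the batch dimension.
def _example_transform_shapes():
    feature = tf.zeros([2, 4 * 5], dtype=tf.int32)   # [bsz, 4*len]
    out = tf.reshape(feature, [2, 4, -1])            # [bsz, 4, len]
    out = tf.transpose(out, [2, 0, 1])               # [len, bsz, 4]
    out = tf.reshape(out, [-1, 2 * 4])               # [len, bsz*4]
    return out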
def embedding_lookup(x, n_token, d_embed, initializer, use_tpu=True,
                     scope='embedding', reuse=None, dtype=tf.float32):
    """TPU and GPU embedding_lookup function."""
    with tf.variable_scope(scope, reuse=reuse):
        lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
                                       dtype=dtype, initializer=initializer)
        if use_tpu:
            one_hot_idx = tf.one_hot(x, n_token, dtype=dtype)
            if one_hot_idx.shape.ndims == 2:
                return (tf.einsum('in,nd->id', one_hot_idx, lookup_table),
                        lookup_table)
            else:
                return (tf.einsum('ibn,nd->ibd', one_hot_idx, lookup_table),
                        lookup_table)
        else:
            return tf.nn.embedding_lookup(lookup_table, x), lookup_table
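
# Illustrative sketch (vocab size and d_embed are assumptions): the TPU path
# (one-hot matmul) and the GPU path (tf.nn.embedding_lookup) produce the same
# [len, bsz, d_embed] embeddings; only the gather implementation differs.
def _example_embedding_lookup(input_ids, initializer):
    emb_tpu, table = embedding_lookup(
        x=input_ids, n_token=32000, d_embed=1024, initializer=initializer,
        use_tpu=True, scope='word_embedding')
    emb_gpu, _ = embedding_lookup(
        x=input_ids, n_token=32000, d_embed=1024, initializer=initializer,
        use_tpu=False, scope='word_embedding', reuse=True)
    return emb_tpu, emb_gpu, table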
def transformer_xl(input_ids, n_token, n_layer, d_model, n_head, d_head,
                   d_inner, dropout, dropatt, attn_type, is_training,
                   initializer,
                   # bi_data, mem_len=None,
                   # inp_q=None, mems=None,
                   same_length=False, clamp_len=-1, untie_r=False,
                   use_tpu=True, input_mask=None, seg_id=None,
                   # perm_mask=None, reuse_len=None, target_mapping=None,
                   ff_activation='relu', use_bfloat16=False,
                   scope='transformer', **kwargs):
    """Defines a Transformer-XL computation graph with additional support for
    XLNet.

    Args:
        input_ids: int32 Tensor in shape [len, bsz], the input token IDs.
        seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
        input_mask: float32 Tensor in shape [len, bsz], the input mask.
            0 for real tokens and 1 for padding.
        mems: a list of float32 Tensors in shape [mem_len, bsz, d_model],
            memory from previous batches. The length of the list equals
            n_layer. If None, no memory is used.
        perm_mask: float32 Tensor in shape [len, len, bsz]. If
            perm_mask[i, j, k] = 0, i attends to j in batch k; if
            perm_mask[i, j, k] = 1, i does not attend to j in batch k.
            If None, each position attends to all the others.
        target_mapping: float32 Tensor in shape [num_predict, len, bsz].
            If target_mapping[i, j, k] = 1, the i-th prediction in batch k is
            on the j-th token. Only used during pretraining for partial
            prediction. Set to None during finetuning.
        inp_q: float32 Tensor in shape [len, bsz]. 1 for tokens with losses
            and 0 for tokens without losses. Only used during pretraining for
            two-stream attention. Set to None during finetuning.
        n_layer: int, the number of layers.
        d_model: int, the hidden size.
        n_head: int, the number of attention heads.
        d_head: int, the dimension size of each attention head.
        d_inner: int, the hidden size in feed-forward layers.
        ff_activation: str, "relu" or "gelu".
        untie_r: bool, whether to untie the biases in attention.
        n_token: int, the vocab size.
        is_training: bool, whether in training mode.
        use_tpu: bool, whether TPUs are used.
        use_bfloat16: bool, use bfloat16 instead of float32.
        dropout: float, dropout rate.
        dropatt: float, dropout rate on attention probabilities.
        init: str, the initialization scheme, either "normal" or "uniform".
        init_range: float, initialize the parameters with a uniform
            distribution in [-init_range, init_range]. Only effective when
            init="uniform".
        init_std: float, initialize the parameters with a normal distribution
            with mean 0 and stddev init_std. Only effective when
            init="normal".
        mem_len: int, the number of tokens to cache.
        reuse_len: int, the number of tokens in the current batch to be cached
            and reused in the future.
        bi_data: bool, whether to use bidirectional input pipeline. Usually
            set to True during pretraining and False during finetuning.
        clamp_len: int, clamp all relative distances larger than clamp_len.
            -1 means no clamping.
        same_length: bool, whether to use the same attention length for each
            token.
        summary_type: str, "last", "first", "mean", or "attn". The method to
            pool the input to get a vector representation.
        initializer: A tf initializer.
        scope: scope name for the computation graph.
    """
    # logger.info('memory input {}'.format(mems))
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float, initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float, initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float, initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float, initializer=initializer)

        batch_size = tf.shape(input_ids)[1]
        seq_len = tf.shape(input_ids)[0]
        # mlen = tf.shape(mems[0])[0] if mems is not None else 0
        mlen = 0
        klen = mlen + seq_len

        # #### Attention mask
        attn_mask = None
        # causal attention mask
        # if attn_type == 'uni':
        #     attn_mask = _create_mask(seq_len, mlen, tf_float, same_length)
        #     attn_mask = attn_mask[:, :, None, None]
        # elif attn_type == 'bi':
        #     attn_mask = None
        # else:
        #     raise ValueError('Unsupported attention type: {}'.format(attn_type))

        # data mask: input mask & perm mask
        data_mask = input_mask[None]
        # if input_mask is not None and perm_mask is not None:
        #     data_mask = input_mask[None] + perm_mask
        # elif input_mask is not None and perm_mask is None:
        #     data_mask = input_mask[None]
        # elif input_mask is None and perm_mask is not None:
        #     data_mask = perm_mask
        # else:
        #     data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, batch_size],
                                 dtype=tf_float)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)

        if attn_mask is not None:
            non_tgt_mask = -tf.eye(seq_len, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([seq_len, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        # #### Word embedding
        word_emb_k, lookup_table = embedding_lookup(
            x=input_ids, n_token=n_token, d_embed=d_model,
            initializer=initializer, use_tpu=use_tpu, dtype=tf_float,
            scope='word_embedding')

        # if inp_q is not None:
        #     with tf.variable_scope('mask_emb'):
        #         mask_emb = tf.get_variable('mask_emb', [1, 1, d_model],
        #                                    dtype=tf_float)
        #         if target_mapping is not None:
        #             word_emb_q = tf.tile(
        #                 mask_emb, [tf.shape(target_mapping)[0], batch_size, 1])
        #         else:
        #             inp_q_ext = inp_q[:, :, None]
        #             word_emb_q = inp_q_ext * mask_emb + (
        #                 1 - inp_q_ext) * word_emb_k

        output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training)
        # if inp_q is not None:
        #     output_g = tf.layers.dropout(word_emb_q, dropout,
        #                                  training=is_training)

        # #### Segment embedding
        if seg_id is not None:
            if untie_r:
                r_s_bias = tf.get_variable(
                    'r_s_bias', [n_layer, n_head, d_head],
                    dtype=tf_float, initializer=initializer)
            else:
                # default case (tie)
                r_s_bias = tf.get_variable(
                    'r_s_bias', [n_head, d_head],
                    dtype=tf_float, initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, batch_size], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        # #### Positional encoding
        pos_emb = relative_positional_encoding(
            seq_len, klen, d_model, clamp_len, attn_type,
            bsz=batch_size, dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        # #### Attention layers
        # if mems is None:
        #     mems = [None] * n_layer
        mems = [None] * n_layer

        for i in range(n_layer):
            # cache new mems
            # new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))
            new_mems.append(None)

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                # if inp_q is not None:
                #     output_h, output_g = two_stream_rel_attn(
                #         h=output_h, g=output_g, r=pos_emb,
                #         r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                #         r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                #         seg_mat=seg_mat, r_s_bias=r_s_bias_i,
                #         seg_embed=seg_embed_i, attn_mask_h=non_tgt_mask,
                #         attn_mask_g=attn_mask, mems=mems[i],
                #         target_mapping=target_mapping, d_model=d_model,
                #         n_head=n_head, d_head=d_head, dropout=dropout,
                #         dropatt=dropatt, is_training=is_training,
                #         kernel_initializer=initializer)
                #     reuse = True
                # else:
                reuse = False
                output_h = rel_multihead_attn(
                    h=output_h, r=pos_emb,
                    r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                    r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                    seg_mat=seg_mat, r_s_bias=r_s_bias_i,
                    seg_embed=seg_embed_i, attn_mask=non_tgt_mask,
                    mems=mems[i], d_model=d_model, n_head=n_head,
                    d_head=d_head, dropout=dropout, dropatt=dropatt,
                    is_training=is_training, kernel_initializer=initializer,
                    reuse=reuse)

                # if inp_q is not None:
                #     output_g = positionwise_ffn(
                #         inp=output_g, d_model=d_model, d_inner=d_inner,
                #         dropout=dropout, kernel_initializer=initializer,
                #         activation_type=ff_activation,
                #         is_training=is_training)

                output_h = positionwise_ffn(
                    inp=output_h, d_model=d_model, d_inner=d_inner,
                    dropout=dropout, kernel_initializer=initializer,
                    activation_type=ff_activation, is_training=is_training,
                    reuse=reuse)

        # if inp_q is not None:
        #     output = tf.layers.dropout(output_g, dropout, training=is_training)
        # else:
        #     output = tf.layers.dropout(output_h, dropout, training=is_training)
        output = tf.layers.dropout(output_h, dropout, training=is_training)

        return output, new_mems, lookup_table
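
# Illustrative sketch (hyper-parameters are assumptions, roughly base-sized):
# building a finetuning graph directly with `transformer_xl`. Inputs follow
# the documented [len, bsz] layout; `initializer` is any tf initializer.
def _example_transformer_xl(input_ids, input_mask, seg_id, is_training):
    initializer = tf.random_normal_initializer(stddev=0.02)
    output, _, _ = transformer_xl(
        input_ids=input_ids, n_token=32000, n_layer=12, d_model=768,
        n_head=12, d_head=64, d_inner=3072, dropout=0.1, dropatt=0.1,
        attn_type='bi', is_training=is_training, initializer=initializer,
        input_mask=input_mask, seg_id=seg_id, use_tpu=False,
        ff_activation='gelu')
    return output  # [len, bsz, d_model]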
def parser(record):
    """function used to parse tfrecord."""
    # NOTE: `seq_len`, `reuse_len`, `perm_size`, `num_predict` and
    # `use_bfloat16` (as well as `_local_perm` and `_convert_example`) are
    # taken from the enclosing scope; `parser` is meant to be created inside
    # an input-function closure.
    record_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # retrieve serialized example
    example = tf.parse_single_example(serialized=record, features=record_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
        inputs[:reuse_len], target[:reuse_len], is_masked[:reuse_len],
        perm_size, reuse_len)

    perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
        inputs[reuse_len:], target[reuse_len:], is_masked[reuse_len:],
        perm_size, non_reuse_len)

    perm_mask_0 = tf.concat(
        [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
    perm_mask_1 = tf.concat(
        [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
    target = tf.concat([target_0, target_1], axis=0)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
    input_k = tf.concat([input_k_0, input_k_1], axis=0)
    input_q = tf.concat([input_q_0, input_q_1], axis=0)

    if num_predict is not None:
        indices = tf.range(seq_len, dtype=tf.int64)
        bool_target_mask = tf.cast(target_mask, tf.bool)
        indices = tf.boolean_mask(indices, bool_target_mask)

        ##### extra padding due to CLS/SEP introduced after prepro
        actual_num_predict = tf.shape(indices)[0]
        pad_len = num_predict - actual_num_predict

        ##### target_mapping
        target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
        paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, paddings], axis=0)
        example["target_mapping"] = tf.reshape(target_mapping,
                                               [num_predict, seq_len])

        ##### target
        target = tf.boolean_mask(target, bool_target_mask)
        paddings = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, paddings], axis=0)
        example["target"] = tf.reshape(target, [num_predict])

        ##### target mask
        target_mask = tf.concat(
            [tf.ones([actual_num_predict], dtype=tf.float32),
             tf.zeros([pad_len], dtype=tf.float32)],
            axis=0)
        example["target_mask"] = tf.reshape(target_mask, [num_predict])
    else:
        example["target"] = tf.reshape(target, [seq_len])
        example["target_mask"] = tf.reshape(target_mask, [seq_len])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for k, v in example.items():
        logger.info("%s: %s", k, v)

    return example
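
# Illustrative sketch (file paths, batch size, and buffer sizes are
# assumptions): how `parser` would typically be wired into a tf.data
# input pipeline over pretraining TFRecords.
def _example_input_fn(file_paths, batch_size=8, training=True):
    dataset = tf.data.TFRecordDataset(file_paths)
    if training:
        dataset = dataset.shuffle(buffer_size=2048).repeat()
    dataset = dataset.map(parser, num_parallel_calls=4)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.prefetch(1)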
def get_decomposed_qa_outputs(FLAGS, features, is_training):
    question_ids = features["question_ids"]
    context_ids = features["context_ids"]
    seq_len = FLAGS.max_seq_length
    q_seq_len = FLAGS.max_first_length + 2
    ctx_seq_len = seq_len - q_seq_len
    q_mask_int = tf.cast(tf.cast(question_ids, tf.bool), tf.int32)
    cls_index = tf.reshape(
        tf.reduce_sum(q_mask_int, axis=1) + ctx_seq_len, [-1])
    # 0 for mask out
    # q_zeros = tf.zeros_like(question_ids)
    # p_ids = tf.concat([context_ids, q_zeros], axis=1)
    # p_mask = tf.cast(tf.cast(p_ids, tf.bool), tf.float32)
    question_ids = tf.transpose(question_ids, [1, 0])
    context_ids = tf.transpose(context_ids, [1, 0])

    q_attn_mask = get_attention_mask(question_ids, q_seq_len)
    c_attn_mask = get_attention_mask(context_ids, ctx_seq_len)
    qc_attn_mask = get_attention_mask(
        tf.concat([context_ids, question_ids], axis=0), seq_len)

    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(is_training, True, FLAGS)
    initializer = xlnet._get_initializer(run_config)
    tfm_args = dict(
        n_token=xlnet_config.n_token,
        initializer=initializer,
        attn_type="bi",
        n_layer=xlnet_config.n_layer,
        d_model=xlnet_config.d_model,
        n_head=xlnet_config.n_head,
        d_head=xlnet_config.d_head,
        d_inner=xlnet_config.d_inner,
        ff_activation=xlnet_config.ff_activation,
        untie_r=xlnet_config.untie_r,
        is_training=run_config.is_training,
        use_bfloat16=run_config.use_bfloat16,
        use_tpu=run_config.use_tpu,
        dropout=run_config.dropout,
        dropatt=run_config.dropatt,
        # mem_len=run_config.mem_len,
        # reuse_len=run_config.reuse_len,
        # bi_data=run_config.bi_data,
        clamp_len=run_config.clamp_len,
        # same_length=run_config.same_length,
        ctx_ids=context_ids,
        q_ids=question_ids,
        q_seq_len=q_seq_len,
        ctx_seq_len=ctx_seq_len,
        sep_layer=FLAGS.sep_layer,
        q_attn_mask=q_attn_mask,
        c_attn_mask=c_attn_mask,
        qc_attn_mask=qc_attn_mask,
    )

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        upper_outputs = transformer_xl_decomposed(**tfm_args)

    output = upper_outputs[-1]
    return_dict = {'upper_outputs': upper_outputs}
    with tf.variable_scope("logits"):
        # logits: seq, batch_size, 2
        logits = tf.layers.dense(output, 2, kernel_initializer=initializer)
        # logits: 2, batch_size, seq
        logits = tf.transpose(logits, [2, 1, 0])
        # start_logits: batch_size, seq
        # end_logits: batch_size, seq
        start_logits, end_logits = tf.unstack(logits, axis=0)
        # start_logits_masked = start_logits * p_mask - 1e30 * (1 - p_mask)
        # start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)
        start_log_probs = tf.nn.log_softmax(start_logits, -1)
        # end_logits_masked = end_logits * p_mask - 1e30 * (1 - p_mask)
        # end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
        end_log_probs = tf.nn.log_softmax(end_logits, -1)
        return_dict["start_logits"] = start_logits
        return_dict["end_logits"] = end_logits
        if is_training:
            return_dict["start_log_probs"] = start_log_probs
            return_dict["end_log_probs"] = end_log_probs

    # an additional layer to predict answer class, 0: span, 1: yes, 2: no
    with tf.variable_scope("answer_class"):
        # get the representation of CLS
        cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
        cls_feature = tf.einsum("lbh,bl->bh", output, cls_index)
        ans_feature = tf.layers.dense(cls_feature, xlnet_config.d_model,
                                      activation=tf.tanh,
                                      kernel_initializer=initializer,
                                      name='pooler')
        ans_feature = tf.layers.dropout(ans_feature, FLAGS.dropout,
                                        training=is_training)
        # hotpot has 3 classes, squad 2.0 has 2 classes
        cls_logits = tf.layers.dense(ans_feature, FLAGS.num_classes,
                                     kernel_initializer=initializer,
                                     name="cls")
        cls_log_probs = tf.nn.log_softmax(cls_logits, -1)
        return_dict["cls_logits"] = cls_logits
        if is_training:
            return_dict["cls_log_probs"] = cls_log_probs

    return return_dict
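
# Illustrative sketch (feature names `start_positions`, `end_positions`,
# `cls` and the 0.5 span weighting are assumptions, not the original
# training code): combining the span loss and the answer-class loss from
# `get_decomposed_qa_outputs` during training.
def _example_decomposed_qa_loss(FLAGS, features):
    outputs = get_decomposed_qa_outputs(FLAGS, features, is_training=True)
    seq_len = FLAGS.max_seq_length
    start_loss = compute_loss(outputs["start_log_probs"],
                              features["start_positions"], depth=seq_len)
    end_loss = compute_loss(outputs["end_log_probs"],
                            features["end_positions"], depth=seq_len)
    cls_loss = compute_loss(outputs["cls_log_probs"],
                            features["cls"], depth=FLAGS.num_classes)
    return (start_loss + end_loss) * 0.5 + cls_loss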
def transformer_xl_decomposed(n_token, n_layer, d_model, n_head, d_head,
                              d_inner, dropout, dropatt, attn_type,
                              is_training, initializer, q_ids, ctx_ids,
                              clamp_len=-1, untie_r=False, use_tpu=True,
                              ff_activation='relu', use_bfloat16=False,
                              sep_layer=9, q_attn_mask=None,
                              c_attn_mask=None, qc_attn_mask=None,
                              q_seq_len=None, ctx_seq_len=None,
                              scope='transformer', **kwargs):
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))

    # new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float, initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float, initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float, initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float, initializer=initializer)

        # batch_size = tf.shape(input_ids)[1]
        # seq_len = tf.shape(input_ids)[0]
        batch_size = tf.shape(q_ids)[1]
        # mlen = tf.shape(mems[0])[0] if mems is not None else 0
        # mlen = 0
        # klen = mlen + seq_len

        # #### Attention mask
        # attn_mask = None
        # data_mask = input_mask[None]
        # if data_mask is not None:
        #     # all mems can be attended to
        #     mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, batch_size],
        #                          dtype=tf_float)
        #     data_mask = tf.concat([mems_mask, data_mask], 1)
        #     if attn_mask is None:
        #         attn_mask = data_mask[:, :, :, None]
        #     else:
        #         attn_mask += data_mask[:, :, :, None]
        # non_tgt_mask = None

        # #### Word embedding
        q_emb, lookup_table = embedding_lookup(
            x=q_ids, n_token=n_token, d_embed=d_model,
            initializer=initializer, use_tpu=use_tpu, dtype=tf_float,
            scope='word_embedding')
        c_emb, _ = embedding_lookup(
            x=ctx_ids, n_token=n_token, d_embed=d_model,
            initializer=initializer, use_tpu=use_tpu, dtype=tf_float,
            reuse=True, scope='word_embedding')

        q_output_h = tf.layers.dropout(q_emb, dropout, training=is_training)
        ctx_output_h = tf.layers.dropout(c_emb, dropout, training=is_training)

        # #### Segment embedding
        if untie_r:
            r_s_bias = tf.get_variable('r_s_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float, initializer=initializer)
        else:
            # default case (tie)
            r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                       dtype=tf_float, initializer=initializer)

        seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head],
                                    dtype=tf_float, initializer=initializer)

        # Convert `seg_id` to one-hot `seg_mat`
        # mem_pad = tf.zeros([mlen, batch_size], dtype=tf.int32)
        # cat_ids = tf.concat([mem_pad, seg_id], 0)
        # `1` indicates not in the same segment [qlen x klen x bsz]
        ctx_seg_ids = tf.zeros_like(ctx_ids, dtype=tf.int32)
        ctx_seg_mat = tf.cast(
            tf.logical_not(tf.equal(ctx_seg_ids[:, None],
                                    ctx_seg_ids[None, :])),
            tf.int32)
        ctx_seg_mat = tf.one_hot(ctx_seg_mat, 2, dtype=tf_float)
        q_seg_ids = tf.ones_like(q_ids, dtype=tf.int32)
        q_seg_mat = tf.cast(
            tf.logical_not(tf.equal(q_seg_ids[:, None], q_seg_ids[None, :])),
            tf.int32)
        q_seg_mat = tf.one_hot(q_seg_mat, 2, dtype=tf_float)
        seg_ids = tf.concat([ctx_seg_ids, q_seg_ids], axis=0)
        seg_mat = tf.cast(
            tf.logical_not(tf.equal(seg_ids[:, None], seg_ids[None, :])),
            tf.int32)
        seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)

        # #### Positional encoding  FIXME: better use of relative pos emb
        q_pos_emb = relative_positional_encoding(
            q_seq_len, q_seq_len, d_model, clamp_len, attn_type,
            bsz=batch_size, dtype=tf_float)
        q_pos_emb = tf.layers.dropout(q_pos_emb, dropout,
                                      training=is_training)
        ctx_pos_emb = relative_positional_encoding(
            ctx_seq_len, ctx_seq_len, d_model, clamp_len, attn_type,
            bsz=batch_size, dtype=tf_float)
        ctx_pos_emb = tf.layers.dropout(ctx_pos_emb, dropout,
                                        training=is_training)
        # pos_emb = tf.concat([ctx_pos_emb, q_pos_emb], axis=0)
        seq_len = ctx_seq_len + q_seq_len
        pos_emb = relative_positional_encoding(
            seq_len, seq_len, d_model, clamp_len, attn_type,
            bsz=batch_size, dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)
        # ctx_pos_emb = pos_emb[q_seq_len:q_seq_len + 2 * ctx_seq_len, :, :]
        # q_pos_emb1 = pos_emb[:q_seq_len, :, :]
        # q_pos_emb2 = pos_emb[q_seq_len + 2 * ctx_seq_len:, :, :]
        # q_pos_emb = tf.concat([q_pos_emb1, q_pos_emb2], axis=0)

        # #### Attention layers
        # mems = [None] * n_layer

        # Lower layers: encode the question and the context separately.
        for i in range(sep_layer):
            r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
            r_w_bias_i = r_w_bias if not untie_r else r_w_bias[i]
            r_r_bias_i = r_r_bias if not untie_r else r_r_bias[i]
            seg_embed_i = seg_embed[i]
            with tf.variable_scope('layer_{}'.format(i)):
                ctx_output_h = rel_multihead_attn(
                    h=ctx_output_h, r=ctx_pos_emb, r_w_bias=r_w_bias_i,
                    r_r_bias=r_r_bias_i, r_s_bias=r_s_bias_i,
                    seg_mat=ctx_seg_mat, seg_embed=seg_embed_i,
                    attn_mask=c_attn_mask, mems=None, d_model=d_model,
                    n_head=n_head, d_head=d_head, dropout=dropout,
                    dropatt=dropatt, is_training=is_training,
                    kernel_initializer=initializer, reuse=False)
                ctx_output_h = positionwise_ffn(
                    inp=ctx_output_h, d_model=d_model, d_inner=d_inner,
                    dropout=dropout, kernel_initializer=initializer,
                    activation_type=ff_activation, is_training=is_training,
                    reuse=False)

                q_output_h = rel_multihead_attn(
                    h=q_output_h, r=q_pos_emb, r_w_bias=r_w_bias_i,
                    r_r_bias=r_r_bias_i, r_s_bias=r_s_bias_i,
                    seg_mat=q_seg_mat, seg_embed=seg_embed_i,
                    attn_mask=q_attn_mask, mems=None, d_model=d_model,
                    n_head=n_head, d_head=d_head, dropout=dropout,
                    dropatt=dropatt, is_training=is_training,
                    kernel_initializer=initializer, reuse=tf.AUTO_REUSE)
                q_output_h = positionwise_ffn(
                    inp=q_output_h, d_model=d_model, d_inner=d_inner,
                    dropout=dropout, kernel_initializer=initializer,
                    activation_type=ff_activation, is_training=is_training,
                    reuse=tf.AUTO_REUSE)

        # Upper layers: concat all q, ctx related variables and attend jointly.
        output_h = tf.concat([ctx_output_h, q_output_h], axis=0)
        upper_outputs = []
        for i in range(sep_layer, n_layer):
            r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
            r_w_bias_i = r_w_bias if not untie_r else r_w_bias[i]
            r_r_bias_i = r_r_bias if not untie_r else r_r_bias[i]
            seg_embed_i = seg_embed[i]
            with tf.variable_scope('layer_{}'.format(i)):
                output_h = rel_multihead_attn(
                    h=output_h, r=pos_emb, seg_mat=seg_mat,
                    r_w_bias=r_w_bias_i, r_r_bias=r_r_bias_i,
                    r_s_bias=r_s_bias_i, seg_embed=seg_embed_i,
                    attn_mask=qc_attn_mask, mems=None, d_model=d_model,
                    n_head=n_head, d_head=d_head, dropout=dropout,
                    dropatt=dropatt, is_training=is_training,
                    kernel_initializer=initializer, reuse=False)
                output_h = positionwise_ffn(
                    inp=output_h, d_model=d_model, d_inner=d_inner,
                    dropout=dropout, kernel_initializer=initializer,
                    activation_type=ff_activation, is_training=is_training,
                    reuse=False)
                upper_outputs.append(output_h)

        output = tf.layers.dropout(output_h, dropout, training=is_training)
        upper_outputs[-1] = output
        return upper_outputs
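
# Illustrative sketch (hyper-parameters and shapes are assumptions): question
# and context are fed as separate [len, bsz] id tensors; layers below
# `sep_layer` run on each stream independently, the remaining layers run on
# the concatenated sequence.
def _example_decomposed_call(q_ids, ctx_ids, q_attn_mask, c_attn_mask,
                             qc_attn_mask, q_seq_len, ctx_seq_len):
    initializer = tf.random_normal_initializer(stddev=0.02)
    upper_outputs = transformer_xl_decomposed(
        n_token=32000, n_layer=12, d_model=768, n_head=12, d_head=64,
        d_inner=3072, dropout=0.1, dropatt=0.1, attn_type='bi',
        is_training=True, initializer=initializer, q_ids=q_ids,
        ctx_ids=ctx_ids, sep_layer=9, q_attn_mask=q_attn_mask,
        c_attn_mask=c_attn_mask, qc_attn_mask=qc_attn_mask,
        q_seq_len=q_seq_len, ctx_seq_len=ctx_seq_len, use_tpu=False)
    return upper_outputs[-1]  # [ctx_seq_len + q_seq_len, bsz, d_model]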
def get_qa_outputs(FLAGS, features, is_training):
    """Loss for downstream span-extraction QA tasks such as SQuAD."""
    input_ids = features["input_ids"]
    seg_id = features["segment_ids"]
    input_mask_int = tf.cast(tf.cast(input_ids, tf.bool), tf.int32)
    cls_index = tf.reshape(tf.reduce_sum(input_mask_int, axis=1), [-1])
    p_mask = tf.cast(tf.cast(seg_id, tf.bool), tf.float32)
    input_ids = tf.transpose(input_ids, [1, 0])
    input_mask = 1 - tf.cast(input_mask_int, tf.float32)
    input_mask = tf.transpose(input_mask, [1, 0])
    seg_id = tf.transpose(seg_id, [1, 0])
    seq_len = tf.shape(input_ids)[0]

    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(is_training, True, FLAGS)

    xlnet_model = xlnet.XLNetModel(
        xlnet_config=xlnet_config,
        run_config=run_config,
        input_ids=input_ids,
        seg_ids=seg_id,
        input_mask=input_mask)
    output = xlnet_model.get_sequence_output()
    initializer = xlnet_model.get_initializer()

    return_dict = {}
    with tf.variable_scope("logits"):
        # logits: seq, batch_size, 2
        logits = tf.layers.dense(output, 2, kernel_initializer=initializer)
        # logits: 2, batch_size, seq
        logits = tf.transpose(logits, [2, 1, 0])
        # start_logits: batch_size, seq
        # end_logits: batch_size, seq
        start_logits, end_logits = tf.unstack(logits, axis=0)
        start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
        start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)
        end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
        end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
        if is_training:
            return_dict["start_log_probs"] = start_log_probs
            return_dict["end_log_probs"] = end_log_probs
        else:
            return_dict["start_logits"] = start_logits
            return_dict["end_logits"] = end_logits

    # an additional layer to predict answer class, 0: span, 1: yes, 2: no
    with tf.variable_scope("answer_class"):
        # get the representation of CLS
        cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
        cls_feature = tf.einsum("lbh,bl->bh", output, cls_index)
        ans_feature = tf.layers.dense(cls_feature, xlnet_config.d_model,
                                      activation=tf.tanh,
                                      kernel_initializer=initializer,
                                      name='pooler')
        ans_feature = tf.layers.dropout(ans_feature, FLAGS.dropout,
                                        training=is_training)
        # hotpot has 3 classes, squad 2.0 has 2 classes
        cls_logits = tf.layers.dense(ans_feature, FLAGS.num_classes,
                                     kernel_initializer=initializer,
                                     name="cls")
        cls_log_probs = tf.nn.log_softmax(cls_logits, -1)
        if is_training:
            return_dict["cls_log_probs"] = cls_log_probs
        return_dict["cls_logits"] = cls_logits

    return return_dict
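
# Illustrative sketch (plain argmax decoding, an assumption rather than the
# original decoding logic): picking a start/end span from the logits that
# `get_qa_outputs` returns at inference time.
def _example_decode_span(return_dict, max_answer_length=30):
    start = tf.nn.log_softmax(return_dict["start_logits"], axis=-1)  # [bsz, seq]
    end = tf.nn.log_softmax(return_dict["end_logits"], axis=-1)      # [bsz, seq]
    # joint score for every (start, end) pair
    scores = start[:, :, None] + end[:, None, :]                     # [bsz, seq, seq]
    # keep only valid spans: end >= start and length <= max_answer_length
    seq = tf.shape(scores)[1]
    valid = tf.linalg.band_part(tf.ones([seq, seq]), 0,
                                max_answer_length - 1)
    scores += (1.0 - valid[None]) * -1e30
    flat = tf.reshape(scores, [tf.shape(scores)[0], -1])
    best = tf.argmax(flat, axis=-1)
    seq64 = tf.cast(seq, best.dtype)
    return best // seq64, best % seq64  # start index, end index per example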