def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                  r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt,
                  is_training, scale):
  """Core relative positional attention operations.

  The pre-softmax score is the sum of three terms, scaled by `scale`:
  content (query + `r_w_bias` against `k_head_h`), position (query +
  `r_r_bias` against `k_head_r`, realigned by `rel_shift`) and an optional
  segment term (query + `r_s_bias` against `seg_embed`, routed through
  `seg_mat`). Index letters follow the einsum strings: scores are laid out
  `ijbn` and the softmax runs over axis 1 (`j`). Presumably `i`/`j` are
  query/key positions, `b` batch and `n` heads; confirm against callers.

  Args:
    q_head: query heads, `ibnd`.
    k_head_h: content key heads, `jbnd`.
    v_head_h: value heads, `jbnd`.
    k_head_r: relative-position key heads, `jbnd` before `rel_shift`.
    seg_embed: segment embeddings, `snd`.
    seg_mat: segment matrix `ijbs`, or None to skip the segment term.
    r_w_bias, r_r_bias, r_s_bias: biases added to `q_head` for the content,
      position and segment terms respectively.
    attn_mask: additive mask; positions where it is 1 receive -1e30 before
      the softmax. None disables masking.
    dropatt: dropout rate applied to the attention probabilities.
    is_training: enables dropout when true.
    scale: multiplier applied to the summed scores.

  Returns:
    The attention output, `ibnd`.
  """
  # Content term.
  content_score = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

  # Position term, shifted so it lines up with the content term's key axis.
  position_score = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
  position_score = rel_shift(position_score,
                             klen=tf.shape(content_score)[1])

  # Segment term; a plain scalar 0 contributes nothing when absent.
  segment_score = 0
  if seg_mat is not None:
    per_segment = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
    segment_score = tf.einsum('ijbs,ibns->ijbn', seg_mat, per_segment)

  logits = (content_score + position_score + segment_score) * scale
  if attn_mask is not None:
    # Additive masking: masked positions are pushed to ~-inf.
    logits = logits - 1e30 * attn_mask

  weights = tf.nn.softmax(logits, 1)
  weights = tf.layers.dropout(weights, dropatt, training=is_training)
  return tf.einsum('ijbn,jbnd->ibnd', weights, v_head_h)
def post_attention(h, attn_vec, d_model, n_head, d_head, dropout, is_training,
                   kernel_initializer, residual=True):
  """Post-attention processing: output projection, dropout, LayerNorm.

  Args:
    h: hidden states added back as a residual when `residual` is true; must
      be addable to the `ibh`-shaped projection.
    attn_vec: attention output heads, `ibnd`.
    d_model, n_head, d_head: dimensions of the `[d_model, n_head, d_head]`
      output kernel `o/kernel`.
    dropout: dropout rate applied to the projected output.
    is_training: enables dropout when true.
    kernel_initializer: initializer for `o/kernel`.
    residual: whether to add `h` before layer normalization.

  Returns:
    The layer-normalized result (variable scope `LayerNorm`), `ibh`.
  """
  out_kernel = tf.get_variable('o/kernel', [d_model, n_head, d_head],
                               dtype=h.dtype, initializer=kernel_initializer)
  # Merge the head axes back down to `d_model`.
  projected = tf.einsum('ibnd,hnd->ibh', attn_vec, out_kernel)
  projected = tf.layers.dropout(projected, dropout, training=is_training)

  norm_input = projected + h if residual else projected
  return tf.contrib.layers.layer_norm(norm_input, begin_norm_axis=-1,
                                      scope='LayerNorm')
def abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt, is_training,
                  scale):
  """Core absolute positional attention operations.

  Plain scaled dot-product attention between `q_head` (`ibnd`) and
  `k_head` (`jbnd`). The softmax runs over axis 1 of the `ijbn` scores.

  Args:
    q_head, k_head, v_head: query / key / value heads.
    attn_mask: additive mask; positions where it is 1 receive -1e30 before
      the softmax. None disables masking.
    dropatt: dropout rate applied to the attention probabilities.
    is_training: enables dropout when true.
    scale: multiplier applied to the raw scores.

  Returns:
    The attention output, `ibnd`.
  """
  scores = tf.einsum('ibnd,jbnd->ijbn', q_head, k_head)
  scores = scores * scale
  if attn_mask is not None:
    scores = scores - 1e30 * attn_mask

  probs = tf.nn.softmax(scores, 1)
  probs = tf.layers.dropout(probs, dropatt, training=is_training)
  return tf.einsum('ijbn,jbnd->ibnd', probs, v_head)
def head_projection(h, d_model, n_head, d_head, kernel_initializer, name):
  """Project hidden states to a specific head with a 4D-shape.

  Args:
    h: hidden states, `ibh` with `h == d_model`.
    d_model, n_head, d_head: kernel dimensions.
    kernel_initializer: initializer for the kernel variable.
    name: prefix of the kernel variable, which is named `<name>/kernel`.

  Returns:
    Per-head projections, `ibnd`.
  """
  kernel_name = '{}/kernel'.format(name)
  kernel_shape = [d_model, n_head, d_head]
  weight = tf.get_variable(kernel_name, kernel_shape, dtype=h.dtype,
                           initializer=kernel_initializer)
  return tf.einsum('ibh,hnd->ibnd', h, weight)
def positional_embedding(pos_seq, inv_freq, bsz=None):
  """Build sinusoidal positional embeddings.

  Takes the outer product of `pos_seq` (length `i`) and `inv_freq` (length
  `d`). The result is `[sin | cos]` concatenated on the last axis, with a
  singleton axis inserted at position 1.

  Args:
    pos_seq: 1-D tensor of positions.
    inv_freq: 1-D tensor of inverse frequencies.
    bsz: if given, the singleton axis is tiled to this size.

  Returns:
    Tensor of shape `[len(pos_seq), 1 or bsz, 2 * len(inv_freq)]`.
  """
  angles = tf.einsum('i,d->id', pos_seq, inv_freq)
  table = tf.concat([tf.sin(angles), tf.cos(angles)], -1)
  table = table[:, None, :]
  return table if bsz is None else tf.tile(table, [1, bsz, 1])
def embedding_lookup(x, n_token, d_embed, initializer, use_tpu=True,
                     scope='embedding', reuse=None, dtype=tf.float32):
  """TPU and GPU embedding_lookup function.

  Creates (or reuses, per `reuse`) a `[n_token, d_embed]` variable named
  `lookup_table` inside `scope` and returns the embeddings for the ids in `x`.

  Args:
    x: integer token ids of any rank.
    n_token: vocabulary size (first dimension of the table).
    d_embed: embedding size (second dimension of the table).
    initializer: initializer for the table.
    use_tpu: when true, the lookup is a one-hot matmul instead of
      `tf.nn.embedding_lookup`.
    scope: variable scope holding the table.
    reuse: passed to `tf.variable_scope`.
    dtype: dtype of the table and of the one-hot encoding.

  Returns:
    A tuple `(embeddings, lookup_table)`. `embeddings` has the shape of `x`
    with a trailing `d_embed` axis.
  """
  with tf.variable_scope(scope, reuse=reuse):
    lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
                                   dtype=dtype, initializer=initializer)
    if not use_tpu:
      return tf.nn.embedding_lookup(lookup_table, x), lookup_table

    # One-hot matmul in place of a gather (presumably chosen because it maps
    # better onto TPU; not verified here).
    one_hot_idx = tf.one_hot(x, n_token, dtype=dtype)
    ndims = one_hot_idx.shape.ndims
    if ndims == 2:
      return tf.einsum('in,nd->id', one_hot_idx, lookup_table), lookup_table
    if ndims == 3 or ndims is None:
      # Unknown rank keeps the previous behavior of assuming 2-D ids.
      return tf.einsum('ibn,nd->ibd', one_hot_idx, lookup_table), lookup_table
    # Any other known rank (scalar ids, or ids of rank >= 3): contract the
    # trailing one-hot axis with the table's first axis. This matches the
    # rank-agnostic non-TPU path.
    return tf.tensordot(one_hot_idx, lookup_table, axes=1), lookup_table
def get_decomposed_qa_outputs(FLAGS, features, is_training):
  """Build span and answer-class heads on the decomposed XLNet.

  Question and context ids are fed separately to
  `transformer_xl_decomposed` (inside variable scope "model" with
  AUTO_REUSE). The heads are built on the last element of the list it
  returns.

  Args:
    FLAGS: config object. Fields read here: max_seq_length, max_first_length,
      model_config_path, sep_layer, dropout, num_classes (plus whatever
      `xlnet.create_run_config` reads).
    features: dict with "question_ids" and "context_ids". Each is laid out
      [batch, seq] (axis 1 is summed per example below, and both are
      transposed before being passed to the model).
    is_training: build the training graph (adds log-probs, enables dropout).

  Returns:
    dict with "upper_outputs", "start_logits", "end_logits" and "cls_logits".
    When `is_training` is true it also has "start_log_probs",
    "end_log_probs" and "cls_log_probs".
  """
  question_ids = features["question_ids"]
  context_ids = features["context_ids"]
  seq_len = FLAGS.max_seq_length
  # NOTE(review): the +2 presumably reserves room for special tokens
  # appended to the question segment; confirm against preprocessing.
  q_seq_len = FLAGS.max_first_length + 2
  ctx_seq_len = seq_len - q_seq_len
  # 1 for non-zero question ids, 0 for zero (padding) ids.
  q_mask_int = tf.cast(tf.cast(question_ids, tf.bool), tf.int32)
  # Index into the concatenated [context; question] sequence: the full
  # context length plus the number of non-zero question ids.
  # NOTE(review): this points one past the last non-zero question id;
  # confirm it matches where preprocessing places the CLS token.
  cls_index = tf.reshape(
      tf.reduce_sum(q_mask_int, axis=1) + ctx_seq_len, [-1])
  # 0 for mask out
  # NOTE(review): p_mask is disabled, so the span log-probs below are
  # computed from unmasked logits (unlike get_qa_outputs).
  # q_zeros = tf.zeros_like(question_ids)
  # p_ids = tf.concat([context_ids, q_zeros], axis=1)
  # p_mask = tf.cast(tf.cast(p_ids, tf.bool), tf.float32)
  # Switch to [seq, batch] layout for the model.
  question_ids = tf.transpose(question_ids, [1, 0])
  context_ids = tf.transpose(context_ids, [1, 0])
  q_attn_mask = get_attention_mask(question_ids, q_seq_len)
  c_attn_mask = get_attention_mask(context_ids, ctx_seq_len)
  # Joint mask over context followed by question.
  qc_attn_mask = get_attention_mask(
      tf.concat([context_ids, question_ids], axis=0), seq_len)
  xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
  run_config = xlnet.create_run_config(is_training, True, FLAGS)
  initializer = xlnet._get_initializer(run_config)
  tfm_args = dict(
      n_token=xlnet_config.n_token,
      initializer=initializer,
      attn_type="bi",
      n_layer=xlnet_config.n_layer,
      d_model=xlnet_config.d_model,
      n_head=xlnet_config.n_head,
      d_head=xlnet_config.d_head,
      d_inner=xlnet_config.d_inner,
      ff_activation=xlnet_config.ff_activation,
      untie_r=xlnet_config.untie_r,
      is_training=run_config.is_training,
      use_bfloat16=run_config.use_bfloat16,
      use_tpu=run_config.use_tpu,
      dropout=run_config.dropout,
      dropatt=run_config.dropatt,
      # mem_len=run_config.mem_len,
      # reuse_len=run_config.reuse_len,
      # bi_data=run_config.bi_data,
      clamp_len=run_config.clamp_len,
      # same_length=run_config.same_length,
      ctx_ids=context_ids,
      q_ids=question_ids,
      q_seq_len=q_seq_len,
      ctx_seq_len=ctx_seq_len,
      sep_layer=FLAGS.sep_layer,
      q_attn_mask=q_attn_mask,
      c_attn_mask=c_attn_mask,
      qc_attn_mask=qc_attn_mask,
  )
  with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
    upper_outputs = transformer_xl_decomposed(**tfm_args)
  # The heads are built on the final element of the returned outputs.
  output = upper_outputs[-1]
  return_dict = {'upper_outputs': upper_outputs}
  with tf.variable_scope("logits"):
    # logits: seq, batch_size, 2
    logits = tf.layers.dense(output, 2, kernel_initializer=initializer)
    # logits: 2, batch_size, seq
    logits = tf.transpose(logits, [2, 1, 0])
    # start_logits: batch_size, seq
    # end_logits: batch_size, seq
    start_logits, end_logits = tf.unstack(logits, axis=0)
    # start_logits_masked = start_logits * p_mask - 1e30 * (1 - p_mask)
    # start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)
    start_log_probs = tf.nn.log_softmax(start_logits, -1)
    # end_logits_masked = end_logits * p_mask - 1e30 * (1 - p_mask)
    # end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
    end_log_probs = tf.nn.log_softmax(end_logits, -1)
  # Raw logits are returned in both modes (get_qa_outputs returns them only
  # at inference).
  return_dict["start_logits"] = start_logits
  return_dict["end_logits"] = end_logits
  if is_training:
    return_dict["start_log_probs"] = start_log_probs
    return_dict["end_log_probs"] = end_log_probs
  # an additional layer to predict answer class, 0: span, 1:yes, 2:no
  with tf.variable_scope("answer_class"):
    # get the representation of CLS
    # (one-hot selection over the sequence axis of `output`)
    cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
    cls_feature = tf.einsum("lbh,bl->bh", output, cls_index)
    ans_feature = tf.layers.dense(cls_feature, xlnet_config.d_model,
                                  activation=tf.tanh,
                                  kernel_initializer=initializer,
                                  name='pooler')
    ans_feature = tf.layers.dropout(ans_feature, FLAGS.dropout,
                                    training=is_training)
    # hotpot has 3 classes,
    # squad 2.0 has 2 classes
    cls_logits = tf.layers.dense(ans_feature, FLAGS.num_classes,
                                 kernel_initializer=initializer,
                                 name="cls")
    cls_log_probs = tf.nn.log_softmax(cls_logits, -1)
  return_dict["cls_logits"] = cls_logits
  if is_training:
    return_dict["cls_log_probs"] = cls_log_probs
  return return_dict
def get_qa_outputs(FLAGS, features, is_training):
  """Loss for downstream span-extraction QA tasks such as SQuAD.

  Runs `xlnet.XLNetModel` on the packed input and builds two heads on its
  sequence output: start/end span logits, and an answer-class classifier
  read at `cls_index`.

  Args:
    FLAGS: config object. Fields read here: model_config_path, dropout,
      num_classes (plus whatever `xlnet.create_run_config` reads).
    features: dict with "input_ids" and "segment_ids", laid out
      [batch, seq]; they are transposed to [seq, batch] for the model.
    is_training: build the training graph (log-probs instead of raw span
      logits, dropout enabled).

  Returns:
    dict. When training: "start_log_probs", "end_log_probs",
    "cls_log_probs", "cls_logits". Otherwise: "start_logits",
    "end_logits", "cls_logits".
  """
  ids_batch_major = features["input_ids"]
  segs_batch_major = features["segment_ids"]

  # 1 for non-zero ids, 0 for zero (padding) ids.
  real_token = tf.cast(tf.cast(ids_batch_major, tf.bool), tf.int32)
  # NOTE(review): the count of non-zero ids is used directly as the CLS
  # position; confirm this matches where preprocessing places CLS.
  cls_index = tf.reshape(tf.reduce_sum(real_token, axis=1), [-1])
  # 1.0 wherever the segment id is non-zero; those positions are excluded
  # from the span log-probs below.
  p_mask = tf.cast(tf.cast(segs_batch_major, tf.bool), tf.float32)

  # Model inputs are [seq, batch]; its mask is 1.0 on padding positions.
  input_ids = tf.transpose(ids_batch_major, [1, 0])
  input_mask = tf.transpose(1 - tf.cast(real_token, tf.float32), [1, 0])
  seg_id = tf.transpose(segs_batch_major, [1, 0])
  seq_len = tf.shape(input_ids)[0]

  xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
  run_config = xlnet.create_run_config(is_training, True, FLAGS)
  xlnet_model = xlnet.XLNetModel(
      xlnet_config=xlnet_config,
      run_config=run_config,
      input_ids=input_ids,
      seg_ids=seg_id,
      input_mask=input_mask)
  output = xlnet_model.get_sequence_output()
  initializer = xlnet_model.get_initializer()

  return_dict = {}
  with tf.variable_scope("logits"):
    # [seq, batch, 2] -> [2, batch, seq]; each unstacked slice is [batch, seq].
    span_logits = tf.transpose(
        tf.layers.dense(output, 2, kernel_initializer=initializer),
        [2, 1, 0])
    start_logits, end_logits = tf.unstack(span_logits, axis=0)
    start_log_probs = tf.nn.log_softmax(
        start_logits * (1 - p_mask) - 1e30 * p_mask, -1)
    end_log_probs = tf.nn.log_softmax(
        end_logits * (1 - p_mask) - 1e30 * p_mask, -1)

  if is_training:
    return_dict.update({"start_log_probs": start_log_probs,
                        "end_log_probs": end_log_probs})
  else:
    # Inference exposes the raw (unmasked) logits.
    return_dict.update({"start_logits": start_logits,
                        "end_logits": end_logits})

  # Answer-class head (original note: 0: span, 1: yes, 2: no; hotpot uses
  # 3 classes, squad 2.0 uses 2 -- the count comes from FLAGS.num_classes).
  with tf.variable_scope("answer_class"):
    # One-hot selection of the hidden state at `cls_index`.
    cls_selector = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
    cls_feature = tf.einsum("lbh,bl->bh", output, cls_selector)
    pooled = tf.layers.dense(cls_feature, xlnet_config.d_model,
                             activation=tf.tanh,
                             kernel_initializer=initializer, name='pooler')
    pooled = tf.layers.dropout(pooled, FLAGS.dropout, training=is_training)
    cls_logits = tf.layers.dense(pooled, FLAGS.num_classes,
                                 kernel_initializer=initializer, name="cls")
    cls_log_probs = tf.nn.log_softmax(cls_logits, -1)

  if is_training:
    return_dict["cls_log_probs"] = cls_log_probs
  return_dict["cls_logits"] = cls_logits
  return return_dict