def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    for (grad, param) in grads_and_vars:
        if grad is None or param is None:
            continue

        param_name = self._get_variable_name(param.name)

        m = tf.get_variable(name=param_name + "/adam_m",
                            shape=param.shape.as_list(),
                            dtype=tf.float32,
                            trainable=False,
                            initializer=tf.zeros_initializer())
        v = tf.get_variable(name=param_name + "/adam_v",
                            shape=param.shape.as_list(),
                            dtype=tf.float32,
                            trainable=False,
                            initializer=tf.zeros_initializer())

        # Standard Adam update.
        next_m = (tf.multiply(self.beta_1, m) +
                  tf.multiply(1.0 - self.beta_1, grad))
        next_v = (tf.multiply(self.beta_2, v) +
                  tf.multiply(1.0 - self.beta_2, tf.square(grad)))
        update = next_m / (tf.sqrt(next_v) + self.epsilon)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't
        # interact with the m/v parameters. This is equivalent to adding the
        # square of the weights to the loss with plain (non-momentum) SGD.
        if self._do_use_weight_decay(param_name):
            update += self.weight_decay_rate * param

        update_with_lr = self.learning_rate * update
        next_param = param - update_with_lr

        assignments.extend(
            [param.assign(next_param),
             m.assign(next_m),
             v.assign(next_v)])

    return tf.group(*assignments, name=name)

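# Illustrative sketch (not part of the original optimizer): the decoupled
# weight-decay step above, written out in NumPy for a single parameter tensor.
# The decay term is added to the raw Adam update, outside of m/v, which is what
# distinguishes this AdamW-style decay from putting an L2 term in the loss.
# All names and default values below are hypothetical.
def _adamw_step_sketch(param, grad, m, v, lr=1e-4, beta_1=0.9, beta_2=0.999,
                       epsilon=1e-6, weight_decay_rate=0.01):
    import numpy as np  # local import so the sketch is self-contained
    next_m = beta_1 * m + (1.0 - beta_1) * grad
    next_v = beta_2 * v + (1.0 - beta_2) * np.square(grad)
    update = next_m / (np.sqrt(next_v) + epsilon)
    # Decoupled weight decay: the decay does not flow through m or v.
    update += weight_decay_rate * param
    next_param = param - lr * update
    return next_param, next_m, next_v
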
def post_attention(h, attn_vec, d_model, n_head, d_head, dropout, is_training,
                   kernel_initializer, residual=True):
    """Post-attention processing."""
    # post-attention projection (back to `d_model`)
    proj_o = tf.get_variable('o/kernel', [d_model, n_head, d_head],
                             dtype=h.dtype, initializer=kernel_initializer)
    attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, proj_o)

    attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)
    if residual:
        output = tf.contrib.layers.layer_norm(attn_out + h,
                                              begin_norm_axis=-1,
                                              scope='LayerNorm')
    else:
        output = tf.contrib.layers.layer_norm(attn_out,
                                              begin_norm_axis=-1,
                                              scope='LayerNorm')

    return output

def head_projection(h, d_model, n_head, d_head, kernel_initializer, name):
    """Project hidden states to a specific head with a 4D-shape."""
    proj_weight = tf.get_variable('{}/kernel'.format(name),
                                  [d_model, n_head, d_head],
                                  dtype=h.dtype,
                                  initializer=kernel_initializer)
    head = tf.einsum('ibh,hnd->ibnd', h, proj_weight)

    return head

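# Usage sketch (illustrative only; assumes the module-level
# `import tensorflow as tf` and a TF 1.x runtime). It shows the shapes that
# `head_projection` and `post_attention` move between:
# [len, bsz, d_model] <-> [len, bsz, n_head, d_head]. The sizes are hypothetical.
def _head_projection_shape_sketch():
    seq_len, bsz, d_model, n_head, d_head = 128, 4, 1024, 16, 64
    init = tf.random_normal_initializer(stddev=0.02)
    h = tf.zeros([seq_len, bsz, d_model])
    with tf.variable_scope('shape_sketch'):
        q_head = head_projection(h, d_model, n_head, d_head, init, 'q')
        # q_head: [seq_len, bsz, n_head, d_head] == [128, 4, 16, 64]
        out = post_attention(h, q_head, d_model, n_head, d_head,
                             dropout=0.1, is_training=False,
                             kernel_initializer=init)
        # out: back to [seq_len, bsz, d_model] == [128, 4, 1024]
    return q_head, out
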
def clean_ckpt(_):
    input_ckpt = FLAGS.clean_input_ckpt
    output_model_dir = FLAGS.clean_output_model_dir
    tf.reset_default_graph()

    var_list = tf.contrib.framework.list_variables(input_ckpt)
    var_values, var_dtypes = {}, {}
    for (name, shape) in var_list:
        if not name.startswith("global_step") and "adam" not in name.lower():
            var_values[name] = None
            logger.info("Include {}".format(name))
        else:
            logger.info("Exclude {}".format(name))

    logger.info("Loading from {}".format(input_ckpt))
    reader = tf.contrib.framework.load_checkpoint(input_ckpt)
    for name in var_values:
        tensor = reader.get_tensor(name)
        var_dtypes[name] = tensor.dtype
        var_values[name] = tensor

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tf_vars = [
            tf.get_variable(v, shape=var_values[v].shape,
                            dtype=var_dtypes[v]) for v in var_values
        ]

    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
    global_step = tf.Variable(0, name="global_step", trainable=False,
                              dtype=tf.int64)
    saver = tf.train.Saver(tf.all_variables())

    if not tf.gfile.Exists(output_model_dir):
        tf.gfile.MakeDirs(output_model_dir)

    # Build a graph consisting only of the kept variables and set them to the
    # values loaded from the input checkpoint.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                               six.iteritems(var_values)):
            sess.run(assign_op, {p: value})

        # Use the built saver to save the cleaned checkpoint.
        saver.save(sess, join(output_model_dir, "model.ckpt"),
                   global_step=global_step)

def avg_checkpoints(model_dir, output_model_dir, last_k):
    tf.reset_default_graph()

    checkpoint_state = tf.train.get_checkpoint_state(model_dir)
    checkpoints = checkpoint_state.all_model_checkpoint_paths[-last_k:]
    var_list = tf.contrib.framework.list_variables(checkpoints[0])
    var_values, var_dtypes = {}, {}
    for (name, shape) in var_list:
        if not name.startswith("global_step"):
            var_values[name] = np.zeros(shape)

    for checkpoint in checkpoints:
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            var_values[name] += tensor
        logger.info("Read from checkpoint %s", checkpoint)

    for name in var_values:  # Average.
        var_values[name] /= len(checkpoints)

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tf_vars = [
            tf.get_variable(v, shape=var_values[v].shape,
                            dtype=var_dtypes[v]) for v in var_values
        ]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
    global_step = tf.Variable(0, name="global_step", trainable=False,
                              dtype=tf.int64)
    saver = tf.train.Saver(tf.all_variables())

    # Build a model consisting only of variables, set them to the average values.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                               six.iteritems(var_values)):
            sess.run(assign_op, {p: value})

        # Use the built saver to save the averaged checkpoint.
        saver.save(sess, join(output_model_dir, "model.ckpt"),
                   global_step=global_step)

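# Usage sketch (illustrative; the paths and `last_k` value are hypothetical).
# This would average the last 3 checkpoints tracked in `model_dir`'s checkpoint
# state and write the result as `model.ckpt-0` under `output_model_dir`:
#
#   avg_checkpoints(model_dir='/tmp/xlnet_finetune',
#                   output_model_dir='/tmp/xlnet_finetune_avg',
#                   last_k=3)
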
def embedding_lookup(x, n_token, d_embed, initializer, use_tpu=True,
                     scope='embedding', reuse=None, dtype=tf.float32):
    """TPU and GPU embedding_lookup function."""
    with tf.variable_scope(scope, reuse=reuse):
        lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
                                       dtype=dtype, initializer=initializer)
        if use_tpu:
            one_hot_idx = tf.one_hot(x, n_token, dtype=dtype)
            if one_hot_idx.shape.ndims == 2:
                return (tf.einsum('in,nd->id', one_hot_idx, lookup_table),
                        lookup_table)
            else:
                return (tf.einsum('ibn,nd->ibd', one_hot_idx, lookup_table),
                        lookup_table)
        else:
            return tf.nn.embedding_lookup(lookup_table, x), lookup_table

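# Illustrative sketch (not part of the original module): the TPU branch above
# replaces a gather with a one-hot matmul, which computes the same lookup.
# The tiny NumPy check below makes that equivalence concrete; the vocabulary
# size and ids are hypothetical.
def _one_hot_lookup_equivalence_sketch():
    import numpy as np  # local import so the sketch is self-contained
    n_token, d_embed = 8, 3
    table = np.random.rand(n_token, d_embed)
    ids = np.array([1, 5, 2])
    gathered = table[ids]           # GPU-style gather
    one_hot = np.eye(n_token)[ids]  # [len, n_token]
    matmul = one_hot @ table        # TPU-style one-hot matmul
    assert np.allclose(gathered, matmul)
    return gathered
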
def summarize_sequence(summary_type, hidden, d_model, n_head, d_head, dropout,
                       dropatt, input_mask, is_training, initializer,
                       scope=None, reuse=None, use_proj=True):
    """
    Different classification tasks may or may not share the same parameters
    to summarize the sequence features. If shared, one can keep the `scope`
    to the default value `None`. Otherwise, one should specify a different
    `scope` for each task.
    """
    # Note: the scope name 'sequnece_summary' (sic) is kept as-is so variable
    # names stay compatible with existing checkpoints.
    with tf.variable_scope(scope, 'sequnece_summary', reuse=reuse):
        if summary_type == 'last':
            summary = hidden[-1]
        elif summary_type == 'first':
            summary = hidden[0]
        elif summary_type == 'mean':
            summary = tf.reduce_mean(hidden, axis=0)
        elif summary_type == 'attn':
            bsz = tf.shape(hidden)[1]

            summary_bias = tf.get_variable('summary_bias', [d_model],
                                           dtype=hidden.dtype,
                                           initializer=initializer)
            summary_bias = tf.tile(summary_bias[None, None], [1, bsz, 1])

            if input_mask is not None:
                input_mask = input_mask[None, :, :, None]

            summary = multihead_attn(summary_bias, hidden, hidden, input_mask,
                                     d_model, n_head, d_head, dropout, dropatt,
                                     is_training, initializer, residual=False)
            summary = summary[0]
        else:
            raise ValueError(
                'Unsupported summary type {}'.format(summary_type))

        # use another projection as in BERT
        if use_proj:
            summary = tf.layers.dense(summary, d_model, activation=tf.tanh,
                                      kernel_initializer=initializer,
                                      name='summary')

        # dropout
        summary = tf.layers.dropout(summary, dropout, training=is_training,
                                    name='dropout')

    return summary

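# Usage sketch (illustrative; shapes and hyperparameters are hypothetical, and
# the module-level `import tensorflow as tf` plus a TF 1.x runtime are
# assumed). With summary_type='mean' the sequence axis (axis 0) is pooled, so
# the result is one [bsz, d_model] vector per example.
def _summarize_sequence_sketch():
    seq_len, bsz, d_model = 128, 4, 1024
    hidden = tf.zeros([seq_len, bsz, d_model])
    summary = summarize_sequence(
        summary_type='mean', hidden=hidden, d_model=d_model, n_head=16,
        d_head=64, dropout=0.1, dropatt=0.1, input_mask=None,
        is_training=False,
        initializer=tf.random_normal_initializer(stddev=0.02))
    # summary: [bsz, d_model] == [4, 1024]
    return summary
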
def transformer_xl(input_ids, n_token, n_layer, d_model, n_head, d_head,
                   d_inner, dropout, dropatt, attn_type, is_training,
                   initializer,
                   # bi_data,
                   mem_len=None,
                   # inp_q=None,
                   mems=None, same_length=False, clamp_len=-1, untie_r=False,
                   use_tpu=True, input_mask=None, seg_id=None,
                   # perm_mask=None,
                   reuse_len=None, target_mapping=None, ff_activation='relu',
                   use_bfloat16=False, scope='transformer', **kwargs):
    """
    Defines a Transformer-XL computation graph with additional support
    for XLNet.

    Args:
      input_ids: int32 Tensor in shape [len, bsz], the input token IDs.
      seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
      input_mask: float32 Tensor in shape [len, bsz], the input mask.
        0 for real tokens and 1 for padding.
      mems: a list of float32 Tensors in shape [mem_len, bsz, d_model],
        memory from previous batches. The length of the list equals n_layer.
        If None, no memory is used.
      perm_mask: float32 Tensor in shape [len, len, bsz].
        If perm_mask[i, j, k] = 0, i attends to j in batch k;
        if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
        If None, each position attends to all the others.
      target_mapping: float32 Tensor in shape [num_predict, len, bsz].
        If target_mapping[i, j, k] = 1, the i-th predict in batch k is on the
        j-th token. Only used during pretraining for partial prediction.
        Set to None during finetuning.
      inp_q: float32 Tensor in shape [len, bsz]. 1 for tokens with losses and
        0 for tokens without losses. Only used during pretraining for
        two-stream attention. Set to None during finetuning.
      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      ff_activation: str, "relu" or "gelu".
      untie_r: bool, whether to untie the biases in attention.
      n_token: int, the vocab size.
      is_training: bool, whether in training mode.
      use_tpu: bool, whether TPUs are used.
      use_bfloat16: bool, use bfloat16 instead of float32.
      dropout: float, dropout rate.
      dropatt: float, dropout rate on attention probabilities.
      init: str, the initialization scheme, either "normal" or "uniform".
      init_range: float, initialize the parameters with a uniform distribution
        in [-init_range, init_range]. Only effective when init="uniform".
      init_std: float, initialize the parameters with a normal distribution
        with mean 0 and stddev init_std. Only effective when init="normal".
      mem_len: int, the number of tokens to cache.
      reuse_len: int, the number of tokens in the current batch to be cached
        and reused in the future.
      bi_data: bool, whether to use bidirectional input pipeline.
        Usually set to True during pretraining and False during finetuning.
      clamp_len: int, clamp all relative distances larger than clamp_len.
        -1 means no clamping.
      same_length: bool, whether to use the same attention length for each
        token.
      summary_type: str, "last", "first", "mean", or "attn". The method to
        pool the input to get a vector representation.
      initializer: A tf initializer.
      scope: scope name for the computation graph.
    """
    # logger.info('memory input {}'.format(mems))
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        batch_size = tf.shape(input_ids)[1]
        seq_len = tf.shape(input_ids)[0]
        # mlen = tf.shape(mems[0])[0] if mems is not None else 0
        mlen = 0
        klen = mlen + seq_len

        # #### Attention mask
        attn_mask = None
        # causal attention mask
        # if attn_type == 'uni':
        #     attn_mask = _create_mask(seq_len, mlen, tf_float, same_length)
        #     attn_mask = attn_mask[:, :, None, None]
        # elif attn_type == 'bi':
        #     attn_mask = None
        # else:
        #     raise ValueError('Unsupported attention type: {}'.format(attn_type))

        # data mask: input mask & perm mask
        data_mask = input_mask[None]
        # if input_mask is not None and perm_mask is not None:
        #     data_mask = input_mask[None] + perm_mask
        # elif input_mask is not None and perm_mask is None:
        #     data_mask = input_mask[None]
        # elif input_mask is None and perm_mask is not None:
        #     data_mask = perm_mask
        # else:
        #     data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, batch_size],
                                 dtype=tf_float)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)

        if attn_mask is not None:
            non_tgt_mask = -tf.eye(seq_len, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([seq_len, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        # #### Word embedding
        word_emb_k, lookup_table = embedding_lookup(x=input_ids,
                                                    n_token=n_token,
                                                    d_embed=d_model,
                                                    initializer=initializer,
                                                    use_tpu=use_tpu,
                                                    dtype=tf_float,
                                                    scope='word_embedding')

        # if inp_q is not None:
        #     with tf.variable_scope('mask_emb'):
        #         mask_emb = tf.get_variable('mask_emb', [1, 1, d_model],
        #                                    dtype=tf_float)
        #         if target_mapping is not None:
        #             word_emb_q = tf.tile(mask_emb,
        #                                  [tf.shape(target_mapping)[0],
        #                                   batch_size, 1])
        #         else:
        #             inp_q_ext = inp_q[:, :, None]
        #             word_emb_q = inp_q_ext * mask_emb + (
        #                 1 - inp_q_ext) * word_emb_k
        output_h = tf.layers.dropout(word_emb_k, dropout,
                                     training=is_training)
        # if inp_q is not None:
        #     output_g = tf.layers.dropout(word_emb_q, dropout,
        #                                  training=is_training)

        # #### Segment embedding
        if seg_id is not None:
            if untie_r:
                r_s_bias = tf.get_variable('r_s_bias',
                                           [n_layer, n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)
            else:
                # default case (tie)
                r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, batch_size], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        # #### Positional encoding
        pos_emb = relative_positional_encoding(seq_len, klen, d_model,
                                               clamp_len, attn_type,
                                               bsz=batch_size,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        # #### Attention layers
        # if mems is None:
        #     mems = [None] * n_layer
        mems = [None] * n_layer

        for i in range(n_layer):
            # cache new mems
            # new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))
            new_mems.append(None)

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                # if inp_q is not None:
                #     output_h, output_g = two_stream_rel_attn(
                #         h=output_h,
                #         g=output_g,
                #         r=pos_emb,
                #         r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                #         r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                #         seg_mat=seg_mat,
                #         r_s_bias=r_s_bias_i,
                #         seg_embed=seg_embed_i,
                #         attn_mask_h=non_tgt_mask,
                #         attn_mask_g=attn_mask,
                #         mems=mems[i],
                #         target_mapping=target_mapping,
                #         d_model=d_model,
                #         n_head=n_head,
                #         d_head=d_head,
                #         dropout=dropout,
                #         dropatt=dropatt,
                #         is_training=is_training,
                #         kernel_initializer=initializer)
                #     reuse = True
                # else:
                reuse = False
                output_h = rel_multihead_attn(
                    h=output_h,
                    r=pos_emb,
                    r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                    r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                    seg_mat=seg_mat,
                    r_s_bias=r_s_bias_i,
                    seg_embed=seg_embed_i,
                    attn_mask=non_tgt_mask,
                    mems=mems[i],
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer,
                    reuse=reuse)

                # if inp_q is not None:
                #     output_g = positionwise_ffn(
                #         inp=output_g,
                #         d_model=d_model,
                #         d_inner=d_inner,
                #         dropout=dropout,
                #         kernel_initializer=initializer,
                #         activation_type=ff_activation,
                #         is_training=is_training)

                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=reuse)

        # if inp_q is not None:
        #     output = tf.layers.dropout(output_g, dropout,
        #                                training=is_training)
        # else:
        #     output = tf.layers.dropout(output_h, dropout,
        #                                training=is_training)
        output = tf.layers.dropout(output_h, dropout, training=is_training)

        return output, new_mems, lookup_table

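# Usage sketch (illustrative; every shape and hyperparameter below is
# hypothetical, and the surrounding module's helpers such as
# `rel_multihead_attn` and `relative_positional_encoding` plus a TF 1.x
# runtime are assumed). It wires up a tiny 2-layer graph the way a finetuning
# model would: token ids and masks come in as [len, bsz] tensors and the final
# hidden states come out as [len, bsz, d_model].
def _transformer_xl_sketch():
    seq_len, bsz = 16, 2
    input_ids = tf.zeros([seq_len, bsz], dtype=tf.int32)
    input_mask = tf.zeros([seq_len, bsz], dtype=tf.float32)  # 0 = real token
    seg_id = tf.zeros([seq_len, bsz], dtype=tf.int32)
    output, new_mems, lookup_table = transformer_xl(
        input_ids=input_ids, n_token=32000, n_layer=2, d_model=64, n_head=4,
        d_head=16, d_inner=256, dropout=0.1, dropatt=0.1, attn_type='bi',
        is_training=False,
        initializer=tf.random_normal_initializer(stddev=0.02),
        use_tpu=False, input_mask=input_mask, seg_id=seg_id)
    # output: [seq_len, bsz, d_model] == [16, 2, 64]
    return output
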
def transformer_xl_decomposed(n_token, n_layer, d_model, n_head, d_head,
                              d_inner, dropout, dropatt, attn_type,
                              is_training, initializer, q_ids, ctx_ids,
                              clamp_len=-1, untie_r=False, use_tpu=True,
                              ff_activation='relu', use_bfloat16=False,
                              sep_layer=9,
                              q_attn_mask=None, c_attn_mask=None,
                              qc_attn_mask=None,
                              q_seq_len=None, ctx_seq_len=None,
                              scope='transformer', **kwargs):
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))

    # new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        # batch_size = tf.shape(input_ids)[1]
        # seq_len = tf.shape(input_ids)[0]
        batch_size = tf.shape(q_ids)[1]
        # mlen = tf.shape(mems[0])[0] if mems is not None else 0
        # mlen = 0
        # klen = mlen + seq_len

        # #### Attention mask
        attn_mask = None
        # data_mask = input_mask[None]
        # if data_mask is not None:
        #     # all mems can be attended to
        #     mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, batch_size],
        #                          dtype=tf_float)
        #     data_mask = tf.concat([mems_mask, data_mask], 1)
        #     if attn_mask is None:
        #         attn_mask = data_mask[:, :, :, None]
        #     else:
        #         attn_mask += data_mask[:, :, :, None]
        # non_tgt_mask = None

        # #### Word embedding
        q_emb, lookup_table = embedding_lookup(x=q_ids, n_token=n_token,
                                               d_embed=d_model,
                                               initializer=initializer,
                                               use_tpu=use_tpu,
                                               dtype=tf_float,
                                               scope='word_embedding')
        c_emb, _ = embedding_lookup(x=ctx_ids, n_token=n_token,
                                    d_embed=d_model,
                                    initializer=initializer,
                                    use_tpu=use_tpu,
                                    dtype=tf_float,
                                    reuse=True,
                                    scope='word_embedding')

        q_output_h = tf.layers.dropout(q_emb, dropout, training=is_training)
        ctx_output_h = tf.layers.dropout(c_emb, dropout, training=is_training)

        # #### Segment embedding
        if untie_r:
            r_s_bias = tf.get_variable('r_s_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            # default case (tie)
            r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head],
                                    dtype=tf_float, initializer=initializer)

        # Convert `seg_id` to one-hot `seg_mat`
        # mem_pad = tf.zeros([mlen, batch_size], dtype=tf.int32)
        # cat_ids = tf.concat([mem_pad, seg_id], 0)
        # `1` indicates not in the same segment [qlen x klen x bsz]
        ctx_seg_ids = tf.zeros_like(ctx_ids, dtype=tf.int32)
        ctx_seg_mat = tf.cast(
            tf.logical_not(tf.equal(ctx_seg_ids[:, None],
                                    ctx_seg_ids[None, :])), tf.int32)
        ctx_seg_mat = tf.one_hot(ctx_seg_mat, 2, dtype=tf_float)

        q_seg_ids = tf.ones_like(q_ids, dtype=tf.int32)
        q_seg_mat = tf.cast(
            tf.logical_not(tf.equal(q_seg_ids[:, None], q_seg_ids[None, :])),
            tf.int32)
        q_seg_mat = tf.one_hot(q_seg_mat, 2, dtype=tf_float)

        seg_ids = tf.concat([ctx_seg_ids, q_seg_ids], axis=0)
        seg_mat = tf.cast(
            tf.logical_not(tf.equal(seg_ids[:, None], seg_ids[None, :])),
            tf.int32)
        seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)

        # #### Positional encoding FIXME: better use of relative pos emb
        q_pos_emb = relative_positional_encoding(q_seq_len, q_seq_len,
                                                 d_model, clamp_len,
                                                 attn_type, bsz=batch_size,
                                                 dtype=tf_float)
        q_pos_emb = tf.layers.dropout(q_pos_emb, dropout,
                                      training=is_training)
        ctx_pos_emb = relative_positional_encoding(ctx_seq_len, ctx_seq_len,
                                                   d_model, clamp_len,
                                                   attn_type, bsz=batch_size,
                                                   dtype=tf_float)
        ctx_pos_emb = tf.layers.dropout(ctx_pos_emb, dropout,
                                        training=is_training)
        # pos_emb = tf.concat([ctx_pos_emb, q_pos_emb], axis=0)
        seq_len = ctx_seq_len + q_seq_len
        pos_emb = relative_positional_encoding(seq_len, seq_len, d_model,
                                               clamp_len, attn_type,
                                               bsz=batch_size,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)
        # ctx_pos_emb = pos_emb[q_seq_len:q_seq_len + 2 * ctx_seq_len, :, :]
        # q_pos_emb1 = pos_emb[:q_seq_len, :, :]
        # q_pos_emb2 = pos_emb[q_seq_len + 2 * ctx_seq_len:, :, :]
        # q_pos_emb = tf.concat([q_pos_emb1, q_pos_emb2], axis=0)

        # #### Attention layers
        # mems = [None] * n_layer

        # Lower layers: encode context and question independently.
        for i in range(sep_layer):
            r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
            r_w_bias_i = r_w_bias if not untie_r else r_w_bias[i]
            r_r_bias_i = r_r_bias if not untie_r else r_r_bias[i]
            seg_embed_i = seg_embed[i]
            with tf.variable_scope('layer_{}'.format(i)):
                ctx_output_h = rel_multihead_attn(
                    h=ctx_output_h,
                    r=ctx_pos_emb,
                    r_w_bias=r_w_bias_i,
                    r_r_bias=r_r_bias_i,
                    r_s_bias=r_s_bias_i,
                    seg_mat=ctx_seg_mat,
                    seg_embed=seg_embed_i,
                    attn_mask=c_attn_mask,
                    mems=None,
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer,
                    reuse=False)
                ctx_output_h = positionwise_ffn(inp=ctx_output_h,
                                                d_model=d_model,
                                                d_inner=d_inner,
                                                dropout=dropout,
                                                kernel_initializer=initializer,
                                                activation_type=ff_activation,
                                                is_training=is_training,
                                                reuse=False)

                q_output_h = rel_multihead_attn(
                    h=q_output_h,
                    r=q_pos_emb,
                    r_w_bias=r_w_bias_i,
                    r_r_bias=r_r_bias_i,
                    r_s_bias=r_s_bias_i,
                    seg_mat=q_seg_mat,
                    seg_embed=seg_embed_i,
                    attn_mask=q_attn_mask,
                    mems=None,
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer,
                    reuse=tf.AUTO_REUSE)
                q_output_h = positionwise_ffn(inp=q_output_h,
                                              d_model=d_model,
                                              d_inner=d_inner,
                                              dropout=dropout,
                                              kernel_initializer=initializer,
                                              activation_type=ff_activation,
                                              is_training=is_training,
                                              reuse=tf.AUTO_REUSE)

        # concat all q, ctx related variables
        output_h = tf.concat([ctx_output_h, q_output_h], axis=0)

        # Upper layers: joint attention over the concatenated sequence.
        upper_outputs = []
        for i in range(sep_layer, n_layer):
            r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
            r_w_bias_i = r_w_bias if not untie_r else r_w_bias[i]
            r_r_bias_i = r_r_bias if not untie_r else r_r_bias[i]
            seg_embed_i = seg_embed[i]
            with tf.variable_scope('layer_{}'.format(i)):
                output_h = rel_multihead_attn(
                    h=output_h,
                    r=pos_emb,
                    seg_mat=seg_mat,
                    r_w_bias=r_w_bias_i,
                    r_r_bias=r_r_bias_i,
                    r_s_bias=r_s_bias_i,
                    seg_embed=seg_embed_i,
                    attn_mask=qc_attn_mask,
                    mems=None,
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer,
                    reuse=False)
                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=False)
            upper_outputs.append(output_h)

        output = tf.layers.dropout(output_h, dropout, training=is_training)
        upper_outputs[-1] = output
        return upper_outputs
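

# Usage sketch (illustrative; all sizes are hypothetical and the module's
# helpers plus a TF 1.x runtime are assumed). The point of the decomposition:
# layers [0, sep_layer) run over the question and the context independently,
# so context representations can be precomputed and cached, while layers
# [sep_layer, n_layer) attend over the concatenated [ctx; q] sequence.
def _transformer_xl_decomposed_sketch():
    q_len, ctx_len, bsz = 16, 64, 2
    q_ids = tf.zeros([q_len, bsz], dtype=tf.int32)
    ctx_ids = tf.zeros([ctx_len, bsz], dtype=tf.int32)
    upper_outputs = transformer_xl_decomposed(
        n_token=32000, n_layer=12, d_model=64, n_head=4, d_head=16,
        d_inner=256, dropout=0.1, dropatt=0.1, attn_type='bi',
        is_training=False,
        initializer=tf.random_normal_initializer(stddev=0.02),
        q_ids=q_ids, ctx_ids=ctx_ids, use_tpu=False, sep_layer=9,
        q_seq_len=q_len, ctx_seq_len=ctx_len)
    # upper_outputs[-1]: [ctx_len + q_len, bsz, d_model] == [80, 2, 64]
    return upper_outputs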