def wrap_rnn(x, cell_type, nlayers, hidden_size, mask=None, bidir=True,
             use_ln=True, concat=True, dropout=0.0, scope=None):
    outputs = [x]
    states = []

    if mask is None:
        xshp = util.shape_list(x)
        mask = tf.ones([xshp[0], xshp[1]], tf.float32)

    for layer in range(nlayers):
        with tf.variable_scope("{}_layer_{}".format(scope or 'rnn', layer)):
            with tf.variable_scope("fw_rnn"):
                _, (o_fw, o_fw_s) = rnn.rnn(cell_type, outputs[-1], hidden_size,
                                            mask=mask, ln=use_ln, sm=False)
            if bidir:
                with tf.variable_scope("bw_rnn"):
                    _, (o_bw, o_bw_s) = rnn.rnn(cell_type, tf.reverse(outputs[-1], [1]),
                                                hidden_size, mask=tf.reverse(mask, [1]),
                                                ln=use_ln, sm=False)
                    o_bw = tf.reverse(o_bw, [1])

            if layer != nlayers - 1:
                o_fw = util.valid_apply_dropout(o_fw, dropout)
                o_fw_s = util.valid_apply_dropout(o_fw_s, dropout)

                if bidir:
                    o_bw = util.valid_apply_dropout(o_bw, dropout)
                    o_bw_s = util.valid_apply_dropout(o_bw_s, dropout)

            if not bidir:
                outputs.append(o_fw)
                states.append(o_fw_s)
            else:
                outputs.append(tf.concat([o_fw, o_bw], -1))
                states.append(tf.concat([o_fw_s, o_bw_s], -1))

    if concat:
        return tf.concat(outputs[1:], -1), tf.concat(states, -1)
    else:
        return outputs[-1], states[-1]

def highway(x, size=None, activation=None, num_layers=2, dropout=0.0,
            ln=False, scope='highway'):
    with tf.variable_scope(scope or "highway"):
        if size is None:
            size = x.shape.as_list()[-1]
        else:
            x = linear(x, size, ln=ln, scope="input_projection")

        for i in range(num_layers):
            T = linear(x, size, ln=ln, scope='gate_%d' % i)
            T = tf.nn.sigmoid(T)

            H = linear(x, size, ln=ln, scope='activation_%d' % i)
            if activation is not None:
                H = activation(H)
            H = util.valid_apply_dropout(H, dropout)

            x = H * T + x * (1.0 - T)

        return x

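
# Illustrative sketch (not part of the original code): the highway layer above
# computes y = H(x) * T(x) + x * (1 - T(x)), where T is a sigmoid transform gate.
# This standalone TF 1.x version uses tf.layers.dense instead of the project's
# `linear`/`util` helpers; the relu activation for H and all names here are
# assumptions for demonstration only.
def highway_sketch(x, num_layers=2):
    size = x.shape.as_list()[-1]
    for i in range(num_layers):
        # transform gate T and candidate activation H both read the same input x
        t = tf.layers.dense(x, size, activation=tf.nn.sigmoid, name='gate_%d' % i)
        h = tf.layers.dense(x, size, activation=tf.nn.relu, name='act_%d' % i)
        # gate mixes the transformed signal with the untouched carry path
        x = h * t + x * (1.0 - t)
    return x
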
def ffn_layer(x, d, d_o, dropout=None, scope=None, numblocks=None):
    """
    FFN layer in Transformer, with Fixup-style initialization
    :param x: input tensor
    :param d: hidden (filter) dimension
    :param d_o: output dimension
    :param dropout: relu dropout rate
    :param numblocks: size of 'L' in the Fixup paper, i.e. the total model depth
    :param scope: variable scope name
    """
    with tf.variable_scope(scope or "ffn_layer",
                           dtype=tf.as_dtype(dtype.floatx())) as scope:
        assert numblocks is not None, 'Fixup requires the total model depth L'

        in_initializer = initializer.scale_initializer(
            math.pow(numblocks, -1. / 2.), scope.initializer)

        x = shift_layer(x)
        hidden = func.linear(x, d, scope="enlarge",
                             weight_initializer=in_initializer, bias=False)
        hidden = shift_layer(hidden)
        hidden = tf.nn.relu(hidden)

        hidden = util.valid_apply_dropout(hidden, dropout)

        hidden = shift_layer(hidden)
        output = func.linear(hidden, d_o, scope="output", bias=False,
                             weight_initializer=tf.zeros_initializer())
        output = scale_layer(output)

        return output

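
# Illustrative sketch (not part of the original code): `shift_layer` and
# `scale_layer` are not defined in this file. In Fixup they are typically a
# learnable scalar bias (initialized to 0) and a learnable scalar multiplier
# (initialized to 1) applied element-wise; the names and scoping below are
# assumptions for demonstration only.
def shift_layer_sketch(x):
    with tf.variable_scope(None, default_name="shift"):
        b = tf.get_variable("bias", [], initializer=tf.zeros_initializer())
        return x + b


def scale_layer_sketch(x):
    with tf.variable_scope(None, default_name="scale"):
        g = tf.get_variable("gain", [], initializer=tf.ones_initializer())
        return x * g
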
def graph(features, params):
    if params.enable_bert:
        s = features['s']

        bert_input = s
        sequence_output = bert.bert_encoder(bert_input, params)
        s_enc = tf.concat(sequence_output[0][-4:], -1)[:, 1:, :]

        sb = features['sb']
        sb_shp = util.shape_list(sb)

        s_coord = tf.stack(
            [util.batch_coordinates(sb_shp[0], sb_shp[1]), sb], axis=2)
        s_enc = tf.gather_nd(s_enc, s_coord)

        features['bert_enc'] = util.valid_apply_dropout(s_enc, params.dropout)

        if not params.use_bert_single:
            features['feature'] = s_enc[:, 0, :]
        else:
            features['feature'] = sequence_output[1]

    features = embedding_layer(features, params)
    features = hierarchy_layer(features, params)
    graph_output = loss_layer(features, params)

    return graph_output

def graph(features, params):
    if params.enable_bert:
        ps = features['ps']
        hs = features['hs']

        bert_input = tf.concat([ps, hs], 1)
        sequence_output = bert.bert_encoder(bert_input, params)
        sequence_feature = bert.bert_feature(sequence_output[0])

        p_len = tf.shape(ps)[1]

        # 1: remove the encoding for `cls`
        p_enc = sequence_feature[:, 1:p_len, :]
        h_enc = sequence_feature[:, p_len:, :]

        pb = features['pb']
        hb = features['hb']
        pb_shp = util.shape_list(pb)
        hb_shp = util.shape_list(hb)

        p_coord = tf.stack(
            [util.batch_coordinates(pb_shp[0], pb_shp[1]), pb], axis=2)
        p_enc = tf.gather_nd(p_enc, p_coord)

        h_coord = tf.stack(
            [util.batch_coordinates(hb_shp[0], hb_shp[1]), hb], axis=2)
        h_enc = tf.gather_nd(h_enc, h_coord)

        features['bert_p_enc'] = util.valid_apply_dropout(p_enc, params.dropout)
        features['bert_h_enc'] = util.valid_apply_dropout(h_enc, params.dropout)

        if not params.use_bert_single:
            features['feature'] = sequence_feature[:, 0, :]
        else:
            features['feature'] = sequence_output[1]

    features = embedding_layer(features, params)
    features = match_layer(features, params)
    features = loss_layer(features, params)

    return features

def additive_attention(query, memory, mem_mask, hidden_size, ln=False,
                       proj_memory=None, num_heads=1, dropout=None,
                       att_fun="add", scope=None):
    """
    additive attention model
    :param query: [batch_size, dim]
    :param memory: [batch_size, seq_len, mem_dim]
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether to use layer normalization
    :param proj_memory: this is the mapped memory for saving memory
    :param num_heads: attention head number
    :param dropout: attention dropout, disabled by default
    :param att_fun: "add" for additive scoring, otherwise dot-product scoring
    :param scope:
    :return: a value matrix, [batch_size, mem_dim]
    """
    with tf.variable_scope(scope or "additive_attention",
                           dtype=tf.as_dtype(dtype.floatx())):
        if proj_memory is None:
            proj_memory = linear(memory, hidden_size, ln=ln, scope="feed_memory")

        query = linear(tf.expand_dims(query, 1), hidden_size,
                       ln=ln, scope="feed_query")

        query = split_heads(query, num_heads)
        proj_memory = split_heads(proj_memory, num_heads)

        if att_fun == "add":
            value = tf.tanh(query + proj_memory)
            logits = linear(value, 1, ln=False, scope="feed_logits")
            logits = tf.squeeze(logits, -1)
        else:
            logits = tf.matmul(query, proj_memory, transpose_b=True)
            logits = tf.squeeze(logits, 2)

        logits = util.mask_scale(logits, tf.expand_dims(mem_mask, 1))
        weights = tf.nn.softmax(logits, -1)  # [batch_size, seq_len]

        dweights = util.valid_apply_dropout(weights, dropout)

        memory = split_heads(memory, num_heads)
        value = tf.reduce_sum(
            tf.expand_dims(dweights, -1) * memory, -2, keepdims=True)
        value = combine_heads(value)
        value = tf.squeeze(value, 1)

        results = {
            'weights': weights,
            'output': value,
            'cache_state': proj_memory
        }

        return results

def ffn_layer(x, d, d_o, dropout=None, scope=None):
    """FFN layer in Transformer"""
    with tf.variable_scope(scope or "ffn_layer"):
        hidden = linear(x, d, scope="enlarge")
        hidden = tf.nn.relu(hidden)

        hidden = util.valid_apply_dropout(hidden, dropout)

        output = linear(hidden, d_o, scope="output")

        return output

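
# Illustrative sketch (not part of the original code): the position-wise FFN above
# is FFN(x) = ReLU(x W1 + b1) W2 + b2, applied to every position independently.
# This version uses tf.layers.dense in place of the project's `linear` helper and
# omits dropout; the names below are assumptions for demonstration only.
def ffn_sketch(x, filter_size, output_size):
    hidden = tf.layers.dense(x, filter_size, activation=tf.nn.relu, name="enlarge")
    return tf.layers.dense(hidden, output_size, name="output")
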
def embedding_layer(features, params):
    t = features['t']
    t_mask = tf.to_float(tf.cast(t, tf.bool))

    with tf.device('/cpu:0'):
        symbol_embeddings = tf.get_variable('special_symbol_embeddings',
                                            shape=(3, params.embed_size),
                                            trainable=True)
        embedding_initializer = tf.glorot_uniform_initializer()
        if params.word_vocab.pretrained_embedding is not None:
            pretrain_embedding = params.word_vocab.pretrained_embedding
            embedding_initializer = tf.constant_initializer(pretrain_embedding)
        general_embeddings = tf.get_variable(
            'general_symbol_embeddings',
            shape=(params.word_vocab.size() - 3, params.embed_size),
            initializer=embedding_initializer,
            trainable=params.word_vocab.pretrained_embedding is None)
        word_embeddings = tf.concat([symbol_embeddings, general_embeddings], 0)

    # apply word dropout
    wd_mask = util.valid_apply_dropout(t_mask, params.word_dropout)
    wd_mask = tf.to_float(tf.cast(wd_mask, tf.bool))

    t_emb = tf.nn.embedding_lookup(word_embeddings, t * tf.to_int32(wd_mask))
    t_emb = t_emb * tf.expand_dims(t_mask, -1)

    embed_features = [t_emb]

    if params.enable_bert:
        embed_features.append(features['bert_enc'])

    if params.use_char:
        c = features['c']
        c_mask = tf.to_float(tf.cast(c, tf.bool))

        c = tf.reshape(c, [-1, tf.shape(c)[-1]])
        c_mask = tf.reshape(c_mask, [-1, tf.shape(c_mask)[-1]])

        with tf.device('/cpu:0'):
            char_embeddings = tf.get_variable(
                'char_embeddings',
                shape=(params.char_vocab.size(), params.char_embed_size),
                initializer=tf.glorot_uniform_initializer(),
                trainable=True)

        with tf.variable_scope('char_embedding'):
            c_emb = tf.nn.embedding_lookup(char_embeddings, c)
            c_emb = util.valid_apply_dropout(c_emb, 0.5 * params.dropout)

        with tf.variable_scope("char_encoding", reuse=tf.AUTO_REUSE):
            c_emb = c_emb * tf.expand_dims(c_mask, -1)

            c_shp = util.shape_list(features['c'])
            c_emb = tf.reshape(
                c_emb, [c_shp[0], c_shp[1], c_shp[2], params.char_embed_size])

            c_state = func.linear(tf.reduce_max(c_emb, 2),
                                  params.char_embed_size, scope="cmap")

        embed_features.append(c_state)

    t_emb = tf.concat(embed_features, axis=2) * tf.expand_dims(t_mask, -1)

    features.update({
        't_emb': t_emb,
        't_mask': t_mask,
    })
    return features

def residual_fn(x, y, dropout=None):
    """Residual connection"""
    y = util.valid_apply_dropout(y, dropout)
    return x + y

def dot_attention(query, memory, mem_mask, hidden_size, ln=False, num_heads=1,
                  cache=None, dropout=None, use_relative_pos=False,
                  max_relative_position=16, out_map=True, scope=None,
                  fuse_mask=None, decode_step=None):
    """
    dotted attention model
    :param query: [batch_size, qey_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, default disable
    :param out_map: output additional mapping
    :param cache: cache-based decoding
    :param fuse_mask: aan mask during training, and timestep for testing
    :param max_relative_position: maximum position considered for relative embedding
    :param use_relative_pos: whether use relative position information
    :param decode_step: the time step of current decoding, 0-based
    :param scope:
    :return: a value matrix, [batch_size, qey_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if fuse_mask is not None:
            assert memory is not None, 'Fuse mechanism only applied with cross-attention'
        if cache and use_relative_pos:
            assert decode_step is not None, 'Decode Step must provide when use relative position encoding'

        if memory is None:
            # suppose self-attention from queries alone
            h = linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                k = linear(memory, hidden_size, ln=ln, scope="k_map")
                v = linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        q *= (hidden_size // num_heads) ** (-0.5)

        q_shp = util.shape_list(q)
        k_shp = util.shape_list(k)
        v_shp = util.shape_list(v)

        q_len = q_shp[2] if decode_step is None else decode_step + 1
        r_lst = None if decode_step is None else 1

        # q * k => attention weights
        if use_relative_pos:
            r = rpr.get_relative_positions_embeddings(
                q_len, k_shp[2], k_shp[3],
                max_relative_position, name="rpr_keys", last=r_lst)
            logits = rpr.relative_attention_inner(q, k, r, transpose=True)
        else:
            logits = tf.matmul(q, k, transpose_b=True)

        if mem_mask is not None:
            logits += mem_mask

        weights = tf.nn.softmax(logits)

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        if use_relative_pos:
            r = rpr.get_relative_positions_embeddings(
                q_len, k_shp[2], v_shp[3],
                max_relative_position, name="rpr_values", last=r_lst)
            o = rpr.relative_attention_inner(dweights, v, r, transpose=False)
        else:
            o = tf.matmul(dweights, v)

        o = combine_heads(o)

        if fuse_mask is not None:
            # this is for AAN, the important part is sharing v_map
            v_q = linear(query, hidden_size, ln=ln, scope="v_map")

            if cache is not None and 'aan' in cache:
                aan_o = (v_q + cache['aan']) / dtype.tf_to_float(fuse_mask + 1)
            else:
                # simplified average attention network
                aan_o = tf.matmul(fuse_mask, v_q)

            if cache is not None:
                if 'aan' not in cache:
                    cache['aan'] = v_q
                else:
                    cache['aan'] = v_q + cache['aan']

            # directly sum both self-attention and cross-attention
            o = o + aan_o

        if out_map:
            o = linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results

def bert_encoder(sequence, params):
    # extract sequence mask information
    seq_mask = 1. - tf.to_float(tf.equal(sequence, params.bert.vocab.pad))

    # extract segment information
    seg_pos = tf.to_float(tf.equal(sequence, params.bert.vocab.sep))
    seg_ids = tf.cumsum(seg_pos, axis=1, reverse=True)
    seg_num = tf.reduce_sum(seg_pos, axis=1, keepdims=True)
    seg_ids = seg_num - seg_ids
    seg_ids = tf.to_int32(seg_ids * seq_mask)

    # sequence length information
    seq_shp = util.shape_list(sequence)
    batch_size, seq_length = seq_shp[:2]

    def custom_getter(getter, name, *args, **kwargs):
        kwargs['trainable'] = params.tune_bert
        return getter(name, *args, **kwargs)

    with tf.variable_scope("bert", custom_getter=custom_getter):
        # handling sequence embeddings: token embedding plus segment embedding plus positional embedding
        embed_initializer = tf.truncated_normal_initializer(stddev=params.bert.initializer_range)
        with tf.variable_scope("embeddings"):
            word_embedding = tf.get_variable(
                name="word_embeddings",
                shape=[params.bert.vocab.size, params.bert.hidden_size],
                initializer=embed_initializer)
            seq_embed = tf.nn.embedding_lookup(word_embedding, sequence)

            segment_embedding = tf.get_variable(
                name="token_type_embeddings",
                shape=[2, params.bert.hidden_size],
                initializer=embed_initializer)
            seg_embed = tf.nn.embedding_lookup(segment_embedding, seg_ids)

            # word embedding + segment embedding
            seq_embed = seq_embed + seg_embed

            # add position embedding
            assert_op = tf.assert_less_equal(seq_length, params.bert.max_position_embeddings)
            with tf.control_dependencies([assert_op]):
                position_embedding = tf.get_variable(
                    name="position_embeddings",
                    shape=[params.bert.max_position_embeddings, params.bert.hidden_size],
                    initializer=embed_initializer)
                pos_embed = position_embedding[:seq_length]

                seq_embed = seq_embed + tf.expand_dims(pos_embed, 0)

            # post-processing: layer normalization and dropout
            seq_embed = tc.layers.layer_norm(
                inputs=seq_embed, begin_norm_axis=-1, begin_params_axis=-1)
            seq_embed = util.valid_apply_dropout(seq_embed, params.bert.hidden_dropout_prob)

        bert_outputs = []

        # handling sequence encoding with transformer encoder
        with tf.variable_scope("encoder"):
            attention_mask = encoder.create_attention_mask_from_input_mask(
                sequence, seq_mask)

            # Run the stacked transformer.
            # `sequence_output` shape = [batch_size, seq_length, hidden_size].
            all_encoder_layers = encoder.transformer_model(
                input_tensor=seq_embed,
                attention_mask=attention_mask,
                hidden_size=params.bert.hidden_size,
                num_hidden_layers=params.bert.num_hidden_layers,
                num_attention_heads=params.bert.num_attention_heads,
                intermediate_size=params.bert.intermediate_size,
                intermediate_act_fn=encoder.get_activation(params.bert.hidden_act),
                hidden_dropout_prob=params.bert.hidden_dropout_prob,
                attention_probs_dropout_prob=params.bert.attention_probs_dropout_prob,
                initializer_range=params.bert.initializer_range,
                do_return_all_layers=True)

        sequence_output = all_encoder_layers
        bert_outputs.append(sequence_output)

        if params.use_bert_single:
            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_size, seq_length, hidden_size] to a tensor of shape
            # [batch_size, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained.
                first_token_tensor = tf.squeeze(sequence_output[-1][:, 0:1, :], axis=1)
                pooled_output = tf.layers.dense(
                    first_token_tensor,
                    params.bert.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=embed_initializer)
                bert_outputs.append(pooled_output)

    return bert_outputs

def dot_attention(query, memory, mem_mask, hidden_size, ln=False, num_heads=1,
                  cache=None, dropout=None, out_map=True, scope=None):
    """
    dotted attention model
    :param query: [batch_size, qey_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, default disable
    :param out_map: output additional mapping
    :param cache: cache-based decoding
    :param scope:
    :return: a value matrix, [batch_size, qey_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if memory is None:
            # suppose self-attention from queries alone
            h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = func.linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                k = func.linear(memory, hidden_size, ln=ln, scope="k_map")
                v = func.linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = func.split_heads(q, num_heads)
        k = func.split_heads(k, num_heads)
        v = func.split_heads(v, num_heads)

        q *= (hidden_size // num_heads) ** (-0.5)

        # q * k => attention weights
        logits = tf.matmul(q, k, transpose_b=True)

        # convert the additive mask to 0-1 form and multiply it into the logits
        if mem_mask is not None:
            zero_one_mask = tf.to_float(tf.equal(mem_mask, 0.0))
            logits *= zero_one_mask

        # replace softmax with relu
        # weights = tf.nn.softmax(logits)
        weights = tf.nn.relu(logits)

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        o = tf.matmul(dweights, v)
        o = func.combine_heads(o)

        # perform RMSNorm to stabilize running
        o = gated_rms_norm(o, scope="post")

        if out_map:
            o = func.linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results

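
# Illustrative sketch (not part of the original code): `gated_rms_norm` is not
# defined in this file. Plain RMSNorm, which the gated variant builds on, rescales
# each vector by its root-mean-square, y = g * x / sqrt(mean(x^2) + eps); the
# variable names and epsilon below are assumptions for demonstration only.
def rms_norm_sketch(x, eps=1e-8, scope="rms_norm"):
    with tf.variable_scope(scope):
        size = x.shape.as_list()[-1]
        gain = tf.get_variable("gain", [size], initializer=tf.ones_initializer())
        ms = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return gain * x * tf.rsqrt(ms + eps)
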
def decoder(target, state, params):
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size

    is_training = ('decoder' not in state)

    # handling target-side word embedding, including shift-padding for training
    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size])
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift
    if is_training:
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
    else:
        inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
                         lambda: tf.zeros_like(inputs),
                         lambda: inputs)
        mask = tf.ones_like(mask)

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("decoder"):
        x = inputs

        init_state = state["decoder_initializer"]["layer"]
        if not is_training:
            init_state = state["decoder"]["state"]["layer"]
        returns = deep_att_dec_rnn(params.cell, x, state["encodes"], hidden_size,
                                   init_state=init_state, mask=mask,
                                   num_heads=params.num_heads,
                                   mem_mask=state["mask"], ln=params.layer_norm,
                                   sm=params.swap_memory,
                                   depth=params.num_decoder_layer)
        (_, hidden_state), (outputs, _), contexts, attentions = returns

        if not is_training:
            state['decoder']['state']['layer'] = hidden_state

        x = outputs
        cshp = util.shape_list(contexts)
        c = tf.reshape(contexts, [cshp[0], cshp[1], cshp[2] * cshp[3]])

    feature = func.linear(tf.concat([x, c, inputs], -1), params.embed_size,
                          ln=params.layer_norm, scope="ff")
    feature = tf.nn.tanh(feature)
    feature = util.valid_apply_dropout(feature, params.dropout)

    if 'dev_decode' in state:
        feature = feature[:, -1, :]

    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size])
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(
        target, util.shape_list(logits)[-1], factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # these mask tricks are mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32),
                   lambda: loss)

    return loss, logits, state, per_sample_loss

def decoder(target, state, params):
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size
    initializer = tf.random_normal_initializer(0.0, hidden_size ** -0.5)

    is_training = ('decoder' not in state)

    if is_training:
        target, mask = util.remove_invalid_seq(target, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size],
                              initializer=initializer)
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target) * (hidden_size ** 0.5)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift
    if is_training:
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
        inputs = func.add_timing_signal(inputs)
    else:
        inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
                         lambda: tf.zeros_like(inputs),
                         lambda: inputs)
        mask = tf.ones_like(mask)
        inputs = func.add_timing_signal(inputs, time=dtype.tf_to_float(state['time']))

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("decoder"):
        x = inputs
        for layer in range(params.num_decoder_layer):
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1) ** -0.5,
                    mode="fan_avg", distribution="uniform")
            else:
                layer_initializer = None
            with tf.variable_scope("layer_{}".format(layer),
                                   initializer=layer_initializer):
                with tf.variable_scope("average_attention"):
                    x_fwds = []
                    for strategy in params.strategies:
                        with tf.variable_scope(strategy):
                            x_fwd = average_attention_strategy(
                                strategy, x, mask, state, layer, params)
                            x_fwds.append(x_fwd)
                    x_fwd = tf.add_n(x_fwds) / len(x_fwds)

                    # FFN activation
                    if params.use_ffn:
                        y = func.ffn_layer(
                            x_fwd,
                            params.filter_size,
                            hidden_size,
                            dropout=params.relu_dropout,
                        )
                    else:
                        y = x_fwd

                    # gating layer
                    z = func.linear(tf.concat([x, y], axis=-1),
                                    hidden_size * 2, scope="z_project")
                    i, f = tf.split(z, 2, axis=-1)
                    y = tf.sigmoid(i) * x + tf.sigmoid(f) * y

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("cross_attention"):
                    y = func.dot_attention(
                        x,
                        state['encodes'],
                        func.attention_bias(state['mask'], "masking"),
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout,
                        cache=None if is_training else
                        state['decoder']['state']['layer_{}'.format(layer)])
                    if not is_training:
                        # mk, mv
                        state['decoder']['state']['layer_{}'.format(layer)] \
                            .update(y['cache'])

                    y = y['output']
                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("feed_forward"):
                    y = func.ffn_layer(
                        x,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

    feature = x
    if 'dev_decode' in state:
        feature = x[:, -1, :]

    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size],
                                  initializer=initializer)
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(
        target, util.shape_list(logits)[-1], factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # these mask tricks are mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32),
                   lambda: loss)

    return loss, logits, state, per_sample_loss

def dot_attention(query, memory, mem_mask, hidden_size, ln=False, num_heads=1,
                  cache=None, dropout=None, use_relative_pos=True,
                  max_relative_position=16, out_map=True, scope=None):
    """
    dotted attention model
    :param query: [batch_size, qey_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, default disable
    :param out_map: output additional mapping
    :param cache: cache-based decoding
    :param max_relative_position: maximum position considered for relative embedding
    :param use_relative_pos: whether use relative position information
    :param scope:
    :return: a value matrix, [batch_size, qey_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention"):
        if memory is None:
            # suppose self-attention from queries alone
            h = linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                h = linear(memory, hidden_size * 2, ln=ln, scope="kv_map")
                k, v = tf.split(h, 2, -1)

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        q *= (hidden_size // num_heads) ** (-0.5)

        q_shp = util.shape_list(q)
        k_shp = util.shape_list(k)
        v_shp = util.shape_list(v)

        # q * k => attention weights
        if use_relative_pos:
            r = get_relative_positions_embeddings(
                q_shp[2], k_shp[2], k_shp[3],
                max_relative_position, name="relative_positions_keys")
            logits = relative_attention_inner(q, k, r, transpose=True)
        else:
            logits = tf.matmul(q, k, transpose_b=True)

        if mem_mask is not None:
            logits += mem_mask

        weights = tf.nn.softmax(logits)

        weights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        if use_relative_pos:
            r = get_relative_positions_embeddings(
                q_shp[2], k_shp[2], v_shp[3],
                max_relative_position, name="relative_positions_values")
            o = relative_attention_inner(weights, v, r, transpose=False)
        else:
            o = tf.matmul(weights, v)

        o = combine_heads(o)

        if out_map:
            o = linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results

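
# Illustrative sketch (not part of the original code): `get_relative_positions_embeddings`
# is not defined in this file. Following Shaw et al. (2018), relative offsets j - i are
# clipped to [-max_relative_position, max_relative_position] and used to index a learned
# table of shape [2 * max + 1, depth]; the names below are assumptions for demonstration.
def relative_positions_embeddings_sketch(q_len, k_len, depth,
                                         max_relative_position, name):
    with tf.variable_scope(name):
        table = tf.get_variable("embeddings",
                                [2 * max_relative_position + 1, depth])
        # clipped distance matrix of shape [q_len, k_len], shifted to be >= 0
        q_idx = tf.range(q_len)[:, None]
        k_idx = tf.range(k_len)[None, :]
        distance = tf.clip_by_value(k_idx - q_idx,
                                    -max_relative_position,
                                    max_relative_position)
        # result shape: [q_len, k_len, depth]
        return tf.gather(table, distance + max_relative_position)
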
def decoder(target, state, params):
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size

    is_training = ('decoder' not in state)

    if is_training:
        target, mask = util.remove_invalid_seq(target, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size])
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift
    if is_training:
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
    else:
        inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
                         lambda: tf.zeros_like(inputs),
                         lambda: inputs)
        mask = tf.ones_like(mask)

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("decoder"):
        x = inputs
        for layer in range(params.num_decoder_layer):
            with tf.variable_scope("layer_{}".format(layer)):
                init_state = state["decoder_initializer"]["layer_{}".format(layer)]
                if not is_training:
                    init_state = state["decoder"]["state"]["layer_{}".format(layer)]
                if layer == 0 or params.use_deep_att:
                    returns = rnn.cond_rnn(params.cell, x, state["encodes"],
                                           hidden_size, init_state=init_state,
                                           mask=mask, num_heads=params.num_heads,
                                           mem_mask=state["mask"],
                                           ln=params.layer_norm,
                                           sm=params.swap_memory,
                                           one2one=False)
                    (_, hidden_state), (outputs, _), contexts, attentions = returns
                    c = contexts
                else:
                    if params.caencoder:
                        returns = rnn.cond_rnn(params.cell, x, c, hidden_size,
                                               init_state=init_state,
                                               mask=mask, mem_mask=mask,
                                               ln=params.layer_norm,
                                               sm=params.swap_memory,
                                               num_heads=params.num_heads,
                                               one2one=True)
                        (_, hidden_state), (outputs, _), contexts, attentions = returns
                    else:
                        outputs = rnn.rnn(params.cell, tf.concat([x, c], -1),
                                          hidden_size, mask=mask,
                                          init_state=init_state,
                                          ln=params.layer_norm,
                                          sm=params.swap_memory)
                        outputs, hidden_state = outputs[1]
                if not is_training:
                    state['decoder']['state']['layer_{}'.format(layer)] = hidden_state

                y = func.linear(outputs, params.embed_size, ln=False, scope="ff")

                # short cut via residual connection
                if x.get_shape()[-1].value == y.get_shape()[-1].value:
                    x = func.residual_fn(x, y, dropout=params.dropout)
                else:
                    x = y
                if params.layer_norm:
                    x = func.layer_norm(x, scope="ln")

    if params.dl4mt_redict:
        feature = func.linear(tf.concat([x, c], -1), params.embed_size,
                              ln=params.layer_norm, scope="ff")
        feature = tf.nn.tanh(feature)

        feature = util.valid_apply_dropout(feature, params.dropout)
    else:
        feature = x

    if 'dev_decode' in state:
        feature = x[:, -1, :]

    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size])
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(
        target, util.shape_list(logits)[-1], factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # these mask tricks are mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32),
                   lambda: loss)

    return loss, logits, state, per_sample_loss

def decoder(target, state, params):
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size
    initializer = tf.random_normal_initializer(0.0, hidden_size ** -0.5)

    is_training = ('decoder' not in state)

    if is_training:
        target, mask = util.remove_invalid_seq(target, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size],
                              initializer=initializer)
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target) * (hidden_size ** 0.5)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift
    if is_training:
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
        inputs = func.add_timing_signal(inputs)
    else:
        inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
                         lambda: tf.zeros_like(inputs),
                         lambda: inputs)
        mask = tf.ones_like(mask)
        inputs = func.add_timing_signal(inputs, time=dtype.tf_to_float(state['time']))

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    # Applying L0Drop
    # --------
    source_memory = state["encodes"]
    source_mask = state["mask"]

    # source_pruning: log alpha_i = x_i w^T
    source_pruning = func.linear(source_memory, 1, scope="source_pruning")

    if is_training:  # training
        source_memory, l0_mask = l0norm.var_train((source_memory, source_pruning))
        l0_norm_loss = tf.squeeze(l0norm.l0_norm(source_pruning), -1)
        l0_norm_loss = tf.reduce_sum(l0_norm_loss * source_mask, -1) / \
            tf.reduce_sum(source_mask, -1)
        l0_norm_loss = tf.reduce_mean(l0_norm_loss)
        l0_norm_loss = l0norm.l0_regularization_loss(
            l0_norm_loss,
            reg_scalar=params.l0_norm_reg_scalar,
            start_reg_ramp_up=params.l0_norm_start_reg_ramp_up,
            end_reg_ramp_up=params.l0_norm_end_reg_ramp_up,
            warm_up=params.l0_norm_warm_up,
        )

        # force the model to only attend to unmasked positions
        source_mask = dtype.tf_to_float(
            tf.cast(tf.squeeze(l0_mask, -1), tf.bool)) * source_mask
    else:  # evaluation
        source_memory, l0_mask = l0norm.var_eval((source_memory, source_pruning))
        l0_norm_loss = 0.0

        source_memory, source_mask, count_mask = extract_encodes(
            source_memory, source_mask, l0_mask)
        count_mask = tf.expand_dims(tf.expand_dims(count_mask, 1), 1)
    # --------

    with tf.variable_scope("decoder"):
        x = inputs
        for layer in range(params.num_decoder_layer):
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1) ** -0.5,
                    mode="fan_avg", distribution="uniform")
            else:
                layer_initializer = None
            with tf.variable_scope("layer_{}".format(layer),
                                   initializer=layer_initializer):
                with tf.variable_scope("self_attention"):
                    y = func.dot_attention(
                        x,
                        None,
                        func.attention_bias(tf.shape(mask)[1], "causal"),
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout,
                        cache=None if is_training else
                        state['decoder']['state']['layer_{}'.format(layer)])
                    if not is_training:
                        # k, v
                        state['decoder']['state']['layer_{}'.format(layer)] \
                            .update(y['cache'])

                    y = y['output']
                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("cross_attention"):
                    if is_training:
                        y = func.dot_attention(
                            x,
                            source_memory,
                            func.attention_bias(source_mask, "masking"),
                            hidden_size,
                            num_heads=params.num_heads,
                            dropout=params.attention_dropout,
                        )
                    else:
                        y = dot_attention(
                            x,
                            source_memory,
                            func.attention_bias(source_mask, "masking"),
                            hidden_size,
                            count_mask=count_mask,
                            num_heads=params.num_heads,
                            dropout=params.attention_dropout,
                            cache=state['decoder']['state']['layer_{}'.format(layer)])
                        # mk, mv
                        state['decoder']['state']['layer_{}'.format(layer)] \
                            .update(y['cache'])

                    y = y['output']
                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("feed_forward"):
                    y = func.ffn_layer(
                        x,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

    feature = x
    if 'dev_decode' in state:
        feature = x[:, -1, :]

    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size],
                                  initializer=initializer)
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(
        target, util.shape_list(logits)[-1], factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    loss = loss + l0_norm_loss

    # these mask tricks are mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, tf.float32),
                   lambda: loss)

    return loss, logits, state, per_sample_loss

def dot_attention(query, memory, mem_mask, hidden_size, ln=False, num_heads=1,
                  cache=None, dropout=None, out_map=True, scope=None,
                  count_mask=None):
    """
    dotted attention model with l0drop
    :param query: [batch_size, qey_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, default disable
    :param out_map: output additional mapping
    :param cache: cache-based decoding
    :param count_mask: counting vector for l0drop
    :param scope:
    :return: a value matrix, [batch_size, qey_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if memory is None:
            # suppose self-attention from queries alone
            h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = func.linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                k = func.linear(memory, hidden_size, ln=ln, scope="k_map")
                v = func.linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = func.split_heads(q, num_heads)
        k = func.split_heads(k, num_heads)
        v = func.split_heads(v, num_heads)

        q *= (hidden_size // num_heads) ** (-0.5)

        # q * k => attention weights
        logits = tf.matmul(q, k, transpose_b=True)

        if mem_mask is not None:
            logits += mem_mask

        # modifying 'weights = tf.nn.softmax(logits)' to include the counting information.
        # --------
        logits = logits - tf.reduce_max(logits, -1, keepdims=True)
        exp_logits = tf.exp(logits)

        # basically, the count considers how many states are dropped (i.e. gate value 0s)
        if count_mask is not None:
            exp_logits *= count_mask

        exp_sum_logits = tf.reduce_sum(exp_logits, -1, keepdims=True)
        weights = exp_logits / exp_sum_logits
        # --------

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        o = tf.matmul(dweights, v)
        o = func.combine_heads(o)

        if out_map:
            o = func.linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results

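
# Illustrative sketch (not part of the original code): the block above reimplements
# softmax so that merged source states can be weighted by how many original states
# they stand for; with count_mask = 1 everywhere it reduces to an ordinary softmax.
# The function name below is an assumption for demonstration only.
def counted_softmax_sketch(logits, count_mask=None):
    # standard max-subtraction for numerical stability
    logits = logits - tf.reduce_max(logits, -1, keepdims=True)
    exp_logits = tf.exp(logits)
    if count_mask is not None:
        # each exponentiated logit is scaled by the number of states it represents
        exp_logits *= count_mask
    return exp_logits / tf.reduce_sum(exp_logits, -1, keepdims=True)
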
def encoder(source, params):
    mask = dtype.tf_to_float(tf.cast(source, tf.bool))
    hidden_size = params.hidden_size

    source, mask = util.remove_invalid_seq(source, mask)

    # extract source word embedding and apply dropout
    embed_name = "embedding" if params.shared_source_target_embedding \
        else "src_embedding"
    src_emb = tf.get_variable(embed_name,
                              [params.src_vocab.size(), params.embed_size])
    src_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(src_emb, source)
    inputs = tf.nn.bias_add(inputs, src_bias)

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    # the encoder module used in the deep attention paper
    with tf.variable_scope("encoder"):
        # x: embedding input, h: the hidden state
        x = inputs
        h = 0
        z = 0
        for layer in range(params.num_encoder_layer + 1):
            with tf.variable_scope("layer_{}".format(layer)):
                if layer == 0:
                    # for the first layer, we run a normal rnn layer to collect context information
                    outputs = rnn.rnn(params.cell, x, hidden_size, mask=mask,
                                      ln=params.layer_norm, sm=params.swap_memory)
                    h = outputs[1][0]
                else:
                    # for deeper encoder layers, we feed both the embedding input and the previous,
                    # reversed hidden state sequence as input:
                    # the embedding informs the current input while the hidden state carries future context
                    is_reverse = (layer % 2 == 1)
                    outputs = rnn.cond_rnn(
                        params.cell,
                        tf.reverse(x, [1]) if is_reverse else x,
                        tf.reverse(h, [1]) if is_reverse else h,
                        hidden_size,
                        mask=tf.reverse(mask, [1]) if is_reverse else mask,
                        ln=params.layer_norm,
                        sm=params.swap_memory,
                        num_heads=params.num_heads,
                        one2one=True)
                    h = outputs[1][0]
                    h = tf.reverse(h, [1]) if is_reverse else h

                # the final hidden state used for decoder state initialization
                z = outputs[1][1]

    with tf.variable_scope("decoder_initializer"):
        decoder_cell = rnn.get_cell(params.cell, hidden_size,
                                    ln=params.layer_norm)

    return {
        "encodes": h,
        "decoder_initializer": {
            'layer': decoder_cell.get_init_state(x=z, scope="dec_init_state")
        },
        "mask": mask
    }

def encoder(source, params):
    mask = dtype.tf_to_float(tf.cast(source, tf.bool))
    hidden_size = params.hidden_size
    initializer = tf.random_normal_initializer(0.0, hidden_size ** -0.5)

    source, mask = util.remove_invalid_seq(source, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "src_embedding"
    src_emb = tf.get_variable(embed_name,
                              [params.src_vocab.size(), params.embed_size],
                              initializer=initializer)
    src_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(src_emb, source) * (hidden_size ** 0.5)
    inputs = tf.nn.bias_add(inputs, src_bias)

    inputs = func.add_timing_signal(inputs)
    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("encoder"):
        x = inputs
        for layer in range(params.num_encoder_layer):
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1) ** -0.5,
                    mode="fan_avg", distribution="uniform")
            else:
                layer_initializer = None
            with tf.variable_scope("layer_{}".format(layer),
                                   initializer=layer_initializer):
                with tf.variable_scope("self_attention"):
                    y = func.dot_attention(
                        x,
                        None,
                        func.attention_bias(mask, "masking"),
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout)

                    y = y['output']
                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("feed_forward"):
                    y = func.ffn_layer(
                        x,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

    source_encodes = x
    x_shp = util.shape_list(x)

    return {
        "encodes": source_encodes,
        "decoder_initializer": {
            "layer_{}".format(l): {
                # plan aan
                "aan": dtype.tf_to_float(tf.zeros([x_shp[0], 1, hidden_size])),
            }
            for l in range(params.num_decoder_layer)
        },
        "mask": mask
    }

def encoder(source, params): mask = dtype.tf_to_float(tf.cast(source, tf.bool)) hidden_size = params.hidden_size source, mask = util.remove_invalid_seq(source, mask) embed_name = "embedding" if params.shared_source_target_embedding \ else "src_embedding" src_emb = tf.get_variable(embed_name, [params.src_vocab.size(), params.embed_size]) src_bias = tf.get_variable("bias", [params.embed_size]) inputs = tf.gather(src_emb, source) inputs = tf.nn.bias_add(inputs, src_bias) inputs = util.valid_apply_dropout(inputs, params.dropout) with tf.variable_scope("encoder"): # forward rnn with tf.variable_scope('forward'): outputs = rnn.rnn(params.cell, inputs, hidden_size, mask=mask, ln=params.layer_norm, sm=params.swap_memory) output_fw, state_fw = outputs[1] # backward rnn with tf.variable_scope('backward'): if not params.caencoder: outputs = rnn.rnn(params.cell, tf.reverse(inputs, [1]), hidden_size, mask=tf.reverse(mask, [1]), ln=params.layer_norm, sm=params.swap_memory) output_bw, state_bw = outputs[1] else: outputs = rnn.cond_rnn(params.cell, tf.reverse(inputs, [1]), tf.reverse(output_fw, [1]), hidden_size, mask=tf.reverse(mask, [1]), ln=params.layer_norm, sm=params.swap_memory, num_heads=params.num_heads, one2one=True) output_bw, state_bw = outputs[1] output_bw = tf.reverse(output_bw, [1]) if not params.caencoder: source_encodes = tf.concat([output_fw, output_bw], -1) source_feature = tf.concat([state_fw, state_bw], -1) else: source_encodes = output_bw source_feature = state_bw with tf.variable_scope("decoder_initializer"): decoder_init = rnn.get_cell( params.cell, hidden_size, ln=params.layer_norm).get_init_state(x=source_feature) decoder_init = tf.tanh(decoder_init) return { "encodes": source_encodes, "decoder_initializer": decoder_init, "mask": mask }
def encoder(source, params): mask = dtype.tf_to_float(tf.cast(source, tf.bool)) hidden_size = params.hidden_size source, mask = util.remove_invalid_seq(source, mask) embed_name = "embedding" if params.shared_source_target_embedding \ else "src_embedding" src_emb = tf.get_variable(embed_name, [params.src_vocab.size(), params.embed_size]) src_bias = tf.get_variable("bias", [params.embed_size]) inputs = tf.gather(src_emb, source) inputs = tf.nn.bias_add(inputs, src_bias) inputs = util.valid_apply_dropout(inputs, params.dropout) with tf.variable_scope("encoder"): x = inputs for layer in range(params.num_encoder_layer): with tf.variable_scope("layer_{}".format(layer)): # forward rnn with tf.variable_scope('forward'): outputs = rnn.rnn(params.cell, x, hidden_size, mask=mask, ln=params.layer_norm, sm=params.swap_memory) output_fw, state_fw = outputs[1] if layer == 0: # backward rnn with tf.variable_scope('backward'): if not params.caencoder: outputs = rnn.rnn(params.cell, tf.reverse(x, [1]), hidden_size, mask=tf.reverse(mask, [1]), ln=params.layer_norm, sm=params.swap_memory) output_bw, state_bw = outputs[1] else: outputs = rnn.cond_rnn(params.cell, tf.reverse(x, [1]), tf.reverse(output_fw, [1]), hidden_size, mask=tf.reverse(mask, [1]), ln=params.layer_norm, sm=params.swap_memory, num_heads=params.num_heads, one2one=True) output_bw, state_bw = outputs[1] output_bw = tf.reverse(output_bw, [1]) if not params.caencoder: y = tf.concat([output_fw, output_bw], -1) z = tf.concat([state_fw, state_bw], -1) else: y = output_bw z = state_bw else: y = output_fw z = state_fw y = func.linear(y, params.embed_size, ln=False, scope="ff") # short cut via residual connection if x.get_shape()[-1].value == y.get_shape()[-1].value: x = func.residual_fn(x, y, dropout=params.dropout) else: x = y if params.layer_norm: x = func.layer_norm(x, scope="ln") if params.embed_size != hidden_size: x = func.layer_norm(func.linear(x, hidden_size, scope="x_map")) with tf.variable_scope("decoder_initializer"): decoder_cell = rnn.get_cell(params.cell, hidden_size, ln=params.layer_norm) return { "encodes": x, "decoder_initializer": { "layer_{}".format(l): decoder_cell.get_init_state(x=z, scope="layer_{}".format(l)) for l in range(params.num_decoder_layer) }, "mask": mask }