def average_attention_strategy(strategy, x, mask, state, layer, params):
    """Average attention network (AAN) forward pass for one decoder layer.

    :param strategy: strategy name; only "aan" is supported
    :param x: decoder layer input, [batch, seq_len, dim]
    :param mask: target-side 0/1 mask, [batch, seq_len]
    :param state: model state dict; contains 'decoder' only at inference time
    :param layer: 0-based decoder layer index (selects the per-layer cache)
    :param params: hyper-parameter container (reads `aan_mask`)
    :return: the forward (cumulative-average) representation of `x`
    :raises NotImplementedError: for any strategy other than "aan"
    """
    strategy = strategy.lower()

    # 'decoder' is only present in `state` during decoding/inference
    is_training = ('decoder' not in state)

    if strategy == "aan":
        if is_training:
            if params.aan_mask:
                # explicit cumulative-average via precomputed AAN weights
                aan_bias = func.attention_bias(mask, "aan")
                x_fwd = tf.matmul(aan_bias, x)
            else:
                # running mean: cumulative sum divided by position count
                aan_bias = tf.cumsum(mask, axis=1)
                # guard against division by zero on fully-masked prefixes
                aan_bias = tf.where(tf.less_equal(aan_bias, 0.),
                                    tf.ones_like(aan_bias),
                                    aan_bias)
                aan_bias = tf.expand_dims(dtype.tf_to_float(aan_bias), 2)

                x_fwd = tf.cumsum(x, axis=1) / aan_bias
        else:
            # incremental decoding: maintain a running sum in the layer cache
            # and divide by the (1-based) current time step
            cache = state['decoder']['state']['layer_{}'.format(layer)]
            x_fwd = (x + cache['aan']) / dtype.tf_to_float(state['time'] + 1)
            cache['aan'] = x + cache['aan']

        return x_fwd
    else:
        raise NotImplementedError("Not supported: {}".format(strategy))
def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4, time=None, name=None):
    """Add Transformer-style sinusoidal positional encodings to `x`.

    :param x: input tensor, [batch, seq_len, channels]
    :param min_timescale: smallest wavelength of the sinusoids
    :param max_timescale: largest wavelength of the sinusoids
    :param time: optional scalar timestep for single-step decoding
    :param name: optional name scope
    :return: `x` plus the positional signal, same shape as `x`
    """
    with tf.name_scope(name, default_name="add_timing_signal", values=[x]):
        seq_len = tf.shape(x)[1]
        depth = tf.shape(x)[2]

        if time is None:
            pos = dtype.tf_to_float(tf.range(seq_len))
        else:
            # decoding position embedding: encode only the current timestep
            pos = tf.expand_dims(time, 0)

        half_depth = depth // 2
        # geometric progression of timescales across half the channels
        log_inc = math.log(float(max_timescale) / float(min_timescale))
        log_inc /= dtype.tf_to_float(half_depth) - 1
        timescales = min_timescale * tf.exp(
            dtype.tf_to_float(tf.range(half_depth)) * -log_inc)

        angles = tf.expand_dims(pos, 1) * tf.expand_dims(timescales, 0)
        enc = tf.concat([tf.sin(angles), tf.cos(angles)], axis=1)
        # odd channel counts get one zero-padded extra channel
        enc = tf.pad(enc, [[0, 0], [0, tf.mod(depth, 2)]])
        enc = tf.reshape(enc, [1, seq_len, depth])

        return x + enc
def extract_encodes(source_memory, source_mask, l0_mask):
    # Compact the encoder states retained by the L0 gate into a shorter
    # sequence, prepending one slot that absorbs all dropped tokens.
    # source_memory: [batch, seq_len, dim]; source_mask: [batch, seq_len];
    # l0_mask: gate values, presumably [batch, seq_len, 1] (squeezed below).
    x_shp = util.shape_list(source_memory)

    # binarize the gate, then restrict it to valid (unpadded) positions
    l0_mask = dtype.tf_to_float(tf.cast(l0_mask, tf.bool))
    l0_mask = tf.squeeze(l0_mask, -1) * source_mask

    # count retained encodings
    k_value = tf.cast(tf.reduce_max(tf.reduce_sum(l0_mask, 1)), tf.int32)
    # batch_size x k_value
    _, topk_indices = tf.nn.top_k(l0_mask, k_value)

    # prepare coordinate
    x_pos = util.batch_coordinates(x_shp[0], k_value)
    coord = tf.stack([x_pos, topk_indices], axis=2)

    # gather retained features
    g_x = tf.gather_nd(source_memory, coord)
    g_mask = tf.gather_nd(l0_mask, coord)

    # padding zero: prepend one all-zero feature slot
    g_x = tf.pad(g_x, [[0, 0], [1, 0], [0, 0]])

    # generate counts, i.e. how many tokens are dropped
    droped_number = tf.reduce_sum(source_mask, 1) - tf.reduce_sum(l0_mask, 1)
    # pad_mask marks batch entries where at least one token was dropped
    pad_mask = dtype.tf_to_float(tf.greater(droped_number, 0.))
    # avoid a zero count when nothing was dropped in a batch entry
    droped_number = tf.where(tf.less_equal(droped_number, 0.),
                             tf.ones_like(droped_number), droped_number)

    count_mask = tf.ones_like(g_mask)
    count_mask = tf.concat([tf.expand_dims(droped_number, 1), count_mask], 1)
    g_mask = tf.concat([tf.expand_dims(pad_mask, 1), g_mask], 1)

    return g_x, g_mask, count_mask
def attention_bias(inputs, mode, inf=None, name=None):
    """ A bias tensor used in attention mechanism"""
    # `inputs` semantics depend on `mode`:
    #   "causal"  -> scalar/0-d sequence length
    #   "masking" -> 0/1 mask, [batch, seq_len]
    #   "aan"     -> 0/1 mask, [batch, seq_len]; returns averaging weights
    if inf is None:
        inf = dtype.inf()

    with tf.name_scope(name, default_name="attention_bias", values=[inputs]):
        if mode == "causal":
            length = inputs
            # lower-triangular matrix: position i may attend to j <= i
            tril = tf.matrix_band_part(tf.ones([length, length]), -1, 0)
            bias = dtype.tf_to_float(- inf * (1.0 - tril))
            return tf.reshape(bias, [1, 1, length, length])
        elif mode == "masking":
            # -inf on padded positions, 0 elsewhere; broadcastable over heads
            bias = (1.0 - inputs) * - inf
            return tf.expand_dims(tf.expand_dims(bias, 1), 1)
        elif mode == "aan":
            length = tf.shape(inputs)[1]
            # cum_factor[j, i] = 1 iff i <= j (cumulative lower triangle)
            cum_factor = tf.expand_dims(tf.cumsum(tf.eye(length), axis=0), 0)
            mask = tf.expand_dims(inputs, 1) * tf.expand_dims(inputs, 2)
            mask *= dtype.tf_to_float(cum_factor)
            # softmax over valid prefix positions yields uniform averaging
            weight = tf.nn.softmax(mask + (1.0 - mask) * - inf)
            return weight * mask
        else:
            raise ValueError("Unknown mode %s" % mode)
def _step_fn(prev, x):
    # One tf.scan step of the conditional RNN: run the lower cell on the
    # input, obtain a context (attention or aligned memory), then run the
    # higher cell.  Relies on closure variables: one2one, mask_pos,
    # cell_lower, cell_higher, additive_attention, memory, mem_mask,
    # mem_shape, proj_memories, ln, num_heads, batch_size, time_steps,
    # init_weight.
    t, h_, c_, a_ = prev

    if not one2one:
        # x packs [projected inputs..., mask]
        m, v = x[mask_pos], x[:mask_pos]
    else:
        # x packs [projected inputs..., mask, projected memories..., memory]
        # NOTE(review): x[-1] is expected to be a per-timestep memory slice
        # of shape [batch, dim] — confirm the memory fed to scan is time-major
        c, c_c, m, v = x[-1], x[mask_pos + 1:-1], x[mask_pos], x[:mask_pos]

    s = cell_lower(h_, v)
    # mask trick: keep the previous hidden state on padded positions
    s = m * s + (1. - m) * h_

    if not one2one:
        vle = additive_attention(
            cell_lower.get_hidden(s), memory, mem_mask,
            mem_shape[-1], ln=ln, num_heads=num_heads,
            proj_memory=proj_memories, scope="attention")
        a, c = vle['weights'], vle['output']
        c_c = cell_higher.fetch_states(c)
    else:
        # one-to-one mapping: "attention" is a one-hot weight at position t
        a = tf.tile(tf.expand_dims(tf.range(time_steps), 0),
                    [batch_size, 1])
        a = dtype.tf_to_float(tf.equal(a, t))
        a = tf.tile(tf.expand_dims(a, 1), [1, num_heads, 1])
        a = tf.reshape(a, tf.shape(init_weight))

    h = cell_higher(s, c_c)
    h = m * h + (1. - m) * s

    return t + 1, h, c, a
def rnn(cell_name, x, d, mask=None, ln=False, init_state=None, sm=True): """Self implemented RNN procedure, supporting mask trick""" # cell_name: gru, lstm or atr # x: input sequence embedding matrix, [batch, seq_len, dim] # d: hidden dimension for rnn # mask: mask matrix, [batch, seq_len] # ln: whether use layer normalization # init_state: the initial hidden states, for cache purpose # sm: whether apply swap memory during rnn scan # dp: variational dropout in_shape = util.shape_list(x) batch_size, time_steps = in_shape[:2] cell = get_cell(cell_name, d, ln=ln) if init_state is None: init_state = cell.get_init_state(shape=[batch_size]) if mask is None: mask = dtype.tf_to_float(tf.ones([batch_size, time_steps])) # prepare projected input cache_inputs = cell.fetch_states(x) cache_inputs = [tf.transpose(v, [1, 0, 2]) for v in list(cache_inputs)] mask_ta = tf.transpose(tf.expand_dims(mask, -1), [1, 0, 2]) def _step_fn(prev, x): t, h_ = prev m = x[-1] v = x[:-1] h = cell(h_, v) h = m * h + (1. - m) * h_ return t + 1, h time = tf.constant(0, dtype=tf.int32, name="time") step_states = (time, init_state) step_vars = cache_inputs + [mask_ta] outputs = tf.scan(_step_fn, step_vars, initializer=step_states, parallel_iterations=32, swap_memory=sm) output_ta = outputs[1] output_state = outputs[1][-1] outputs = tf.transpose(output_ta, [1, 0, 2]) return (outputs, output_state), \ (cell.get_hidden(outputs), cell.get_hidden(output_state))
def _get_init_state(self, d, shape=None, x=None, scope=None):
    """Generate an initial hidden-state vector of size `d`.

    :param d: hidden dimension of the state
    :param shape: leading dimensions (int, list or tuple) for the
        zero-initialized state; required when `x` is None
    :param x: optional evidence tensor; when given, the state is a linear
        projection of `x` instead of zeros
    :param scope: optional scope name for the projection layer
    :return: a floatx tensor of shape `shape + [d]`, or `linear(x, d)`
    """
    # gen init state vector
    # if no evidence x is provided, use zero initialization
    if x is None:
        assert shape is not None, "you should provide shape"
        if not isinstance(shape, (tuple, list)):
            shape = [shape]
        # list(...) so that tuple shapes concatenate correctly
        # (the old `shape + [d]` raised TypeError for tuple inputs)
        shape = list(shape) + [d]
        return dtype.tf_to_float(tf.zeros(shape))
    else:
        return linear(x, d, bias=True, ln=self.ln,
                      scope="{}_init".format(scope or self.scope))
def encoder(source, params):
    # Deep (multi-layer) RNN encoder: layer 0 is bidirectional (or a
    # CA-encoder), upper layers are forward-only, with residual connections
    # between layers.  Also builds per-layer decoder initial states.
    mask = dtype.tf_to_float(tf.cast(source, tf.bool))
    hidden_size = params.hidden_size

    source, mask = util.remove_invalid_seq(source, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "src_embedding"
    src_emb = tf.get_variable(embed_name,
                              [params.src_vocab.size(), params.embed_size])
    src_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(src_emb, source)
    inputs = tf.nn.bias_add(inputs, src_bias)

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("encoder"):
        x = inputs

        for layer in range(params.num_encoder_layer):
            with tf.variable_scope("layer_{}".format(layer)):
                # forward rnn
                with tf.variable_scope('forward'):
                    outputs = rnn.rnn(params.cell, x, hidden_size,
                                      mask=mask, ln=params.layer_norm,
                                      sm=params.swap_memory)
                    output_fw, state_fw = outputs[1]

                if layer == 0:
                    # backward rnn (only the first layer is bidirectional)
                    with tf.variable_scope('backward'):
                        if not params.caencoder:
                            outputs = rnn.rnn(params.cell,
                                              tf.reverse(x, [1]),
                                              hidden_size,
                                              mask=tf.reverse(mask, [1]),
                                              ln=params.layer_norm,
                                              sm=params.swap_memory)
                            output_bw, state_bw = outputs[1]
                        else:
                            # context-aware encoder: backward pass is
                            # conditioned one-to-one on the forward outputs
                            outputs = rnn.cond_rnn(params.cell,
                                                   tf.reverse(x, [1]),
                                                   tf.reverse(output_fw, [1]),
                                                   hidden_size,
                                                   mask=tf.reverse(mask, [1]),
                                                   ln=params.layer_norm,
                                                   sm=params.swap_memory,
                                                   num_heads=params.num_heads,
                                                   one2one=True)
                            output_bw, state_bw = outputs[1]

                        output_bw = tf.reverse(output_bw, [1])

                    if not params.caencoder:
                        y = tf.concat([output_fw, output_bw], -1)
                        z = tf.concat([state_fw, state_bw], -1)
                    else:
                        y = output_bw
                        z = state_bw
                else:
                    y = output_fw
                    z = state_fw

                y = func.linear(y, params.embed_size, ln=False, scope="ff")

                # short cut via residual connection
                if x.get_shape()[-1].value == y.get_shape()[-1].value:
                    x = func.residual_fn(x, y, dropout=params.dropout)
                else:
                    x = y
                if params.layer_norm:
                    x = func.layer_norm(x, scope="ln")

        if params.embed_size != hidden_size:
            x = func.layer_norm(func.linear(x, hidden_size, scope="x_map"))

    with tf.variable_scope("decoder_initializer"):
        decoder_cell = rnn.get_cell(params.cell, hidden_size,
                                    ln=params.layer_norm)

    return {
        "encodes": x,
        # one projected initial state per decoder layer, derived from the
        # final encoder state z
        "decoder_initializer": {
            "layer_{}".format(l): decoder_cell.get_init_state(
                x=z, scope="layer_{}".format(l))
            for l in range(params.num_decoder_layer)
        },
        "mask": mask
    }
def encoder(source, params):
    # Bidirectional RNN encoder (rnnsearch style): forward + backward RNNs
    # (or a CA-encoder); the concatenated final states initialize the decoder.
    mask = dtype.tf_to_float(tf.cast(source, tf.bool))
    hidden_size = params.hidden_size

    source, mask = util.remove_invalid_seq(source, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "src_embedding"
    src_emb = tf.get_variable(embed_name,
                              [params.src_vocab.size(), params.embed_size])
    src_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(src_emb, source)
    inputs = tf.nn.bias_add(inputs, src_bias)

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("encoder"):
        # forward rnn
        with tf.variable_scope('forward'):
            outputs = rnn.rnn(params.cell, inputs, hidden_size,
                              mask=mask, ln=params.layer_norm,
                              sm=params.swap_memory)
            output_fw, state_fw = outputs[1]
        # backward rnn
        with tf.variable_scope('backward'):
            if not params.caencoder:
                outputs = rnn.rnn(params.cell, tf.reverse(inputs, [1]),
                                  hidden_size, mask=tf.reverse(mask, [1]),
                                  ln=params.layer_norm,
                                  sm=params.swap_memory)
                output_bw, state_bw = outputs[1]
            else:
                # context-aware encoder: backward pass conditioned
                # one-to-one on the (reversed) forward outputs
                outputs = rnn.cond_rnn(params.cell, tf.reverse(inputs, [1]),
                                       tf.reverse(output_fw, [1]),
                                       hidden_size,
                                       mask=tf.reverse(mask, [1]),
                                       ln=params.layer_norm,
                                       sm=params.swap_memory,
                                       num_heads=params.num_heads,
                                       one2one=True)
                output_bw, state_bw = outputs[1]
            output_bw = tf.reverse(output_bw, [1])

    if not params.caencoder:
        source_encodes = tf.concat([output_fw, output_bw], -1)
        source_feature = tf.concat([state_fw, state_bw], -1)
    else:
        source_encodes = output_bw
        source_feature = state_bw

    with tf.variable_scope("decoder_initializer"):
        decoder_init = rnn.get_cell(
            params.cell, hidden_size,
            ln=params.layer_norm).get_init_state(x=source_feature)
    decoder_init = tf.tanh(decoder_init)

    return {
        "encodes": source_encodes,
        "decoder_initializer": decoder_init,
        "mask": mask
    }
def cond_rnn(cell_name, x, memory, d, init_state=None,
             mask=None, mem_mask=None, ln=False, sm=True,
             one2one=False, num_heads=1):
    """Self implemented conditional-RNN procedure, supporting mask trick"""
    # cell_name: gru, lstm or atr
    # x: input sequence embedding matrix, [batch, seq_len, dim]
    # memory: the conditional part
    # d: hidden dimension for rnn
    # mask: mask matrix, [batch, seq_len]
    # mem_mask: memory mask matrix, [batch, mem_seq_len]
    # ln: whether use layer normalization
    # init_state: the initial hidden states, for cache purpose
    # sm: whether apply swap memory during rnn scan
    # one2one: whether the memory is one-to-one mapping for x
    # num_heads: number of attention heads, multi-head attention
    # dp: variational dropout

    in_shape = util.shape_list(x)
    batch_size, time_steps = in_shape[:2]
    mem_shape = util.shape_list(memory)

    # two stacked sub-cells per step: the lower cell consumes the input,
    # the higher cell consumes the (attention or aligned) context
    cell_lower = get_cell(cell_name, d, ln=ln,
                          scope="{}_lower".format(cell_name))
    cell_higher = get_cell(cell_name, d, ln=ln,
                           scope="{}_higher".format(cell_name))

    if init_state is None:
        init_state = cell_lower.get_init_state(shape=[batch_size])
    if mask is None:
        mask = dtype.tf_to_float(tf.ones([batch_size, time_steps]))
    if mem_mask is None:
        mem_mask = dtype.tf_to_float(tf.ones([batch_size, mem_shape[1]]))

    # prepare projected encodes and inputs (time-major for tf.scan)
    cache_inputs = cell_lower.fetch_states(x)
    cache_inputs = [tf.transpose(v, [1, 0, 2]) for v in list(cache_inputs)]
    if not one2one:
        # attention-based conditioning: project the memory once
        proj_memories = linear(memory, mem_shape[-1], bias=False,
                               ln=ln, scope="context_att")
    else:
        # aligned conditioning: precompute per-position memory states
        cache_memories = cell_higher.fetch_states(memory)
        cache_memories = [
            tf.transpose(v, [1, 0, 2]) for v in list(cache_memories)
        ]
    mask_ta = tf.transpose(tf.expand_dims(mask, -1), [1, 0, 2])
    init_context = dtype.tf_to_float(tf.zeros([batch_size, mem_shape[-1]]))
    init_weight = dtype.tf_to_float(
        tf.zeros([batch_size, num_heads, mem_shape[1]]))
    mask_pos = len(cache_inputs)

    def _step_fn(prev, x):
        # one scan step: lower cell -> attention/aligned context -> higher cell
        t, h_, c_, a_ = prev

        if not one2one:
            # x packs [projected inputs..., mask]
            m, v = x[mask_pos], x[:mask_pos]
        else:
            # x packs [projected inputs..., mask, projected memories..., memory]
            c, c_c, m, v = x[-1], x[mask_pos + 1:-1], x[mask_pos], x[:mask_pos]

        s = cell_lower(h_, v)
        # mask trick: keep the previous state on padded positions
        s = m * s + (1. - m) * h_

        if not one2one:
            vle = additive_attention(
                cell_lower.get_hidden(s), memory, mem_mask,
                mem_shape[-1], ln=ln, num_heads=num_heads,
                proj_memory=proj_memories, scope="attention")
            a, c = vle['weights'], vle['output']
            c_c = cell_higher.fetch_states(c)
        else:
            # one-to-one mapping: "attention" is a one-hot weight at step t
            a = tf.tile(tf.expand_dims(tf.range(time_steps), 0),
                        [batch_size, 1])
            a = dtype.tf_to_float(tf.equal(a, t))
            a = tf.tile(tf.expand_dims(a, 1), [1, num_heads, 1])
            a = tf.reshape(a, tf.shape(init_weight))

        h = cell_higher(s, c_c)
        h = m * h + (1. - m) * s

        return t + 1, h, c, a

    time = tf.constant(0, dtype=tf.int32, name="time")
    step_states = (time, init_state, init_context, init_weight)
    step_vars = cache_inputs + [mask_ta]
    if one2one:
        # BUGFIX: every scanned element must be time-major so that the step
        # function receives a [batch, dim] slice per timestep.  `memory` was
        # previously appended batch-major, which made tf.scan slice along the
        # batch axis (and mismatch the [batch, mem_dim] context accumulator).
        step_vars += cache_memories + [tf.transpose(memory, [1, 0, 2])]

    outputs = tf.scan(_step_fn, step_vars,
                      initializer=step_states,
                      parallel_iterations=32,
                      swap_memory=sm)

    output_ta = outputs[1]
    context_ta = outputs[2]
    attention_ta = outputs[3]

    # back to batch-major
    outputs = tf.transpose(output_ta, [1, 0, 2])
    output_states = outputs[:, -1]
    contexts = tf.transpose(context_ta, [1, 0, 2])
    # [batch, num_heads, target_len, mem_len]
    attentions = tf.transpose(attention_ta, [1, 2, 0, 3])

    return (outputs, output_states), \
           (cell_higher.get_hidden(outputs),
            cell_higher.get_hidden(output_states)), \
        contexts, attentions
def dot_attention(query, memory, mem_mask, hidden_size,
                  ln=False, num_heads=1, cache=None, dropout=None,
                  use_relative_pos=False, max_relative_position=16,
                  out_map=True, scope=None, fuse_mask=None,
                  decode_step=None):
    """
    dotted attention model
    :param query: [batch_size, qey_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, default disable
    :param out_map: output additional mapping
    :param cache: cache-based decoding
    :param fuse_mask: aan mask during training, and timestep for testing
    :param max_relative_position: maximum position considered for relative embedding
    :param use_relative_pos: whether use relative position information
    :param decode_step: the time step of current decoding, 0-based
    :param scope:
    :return: a value matrix, [batch_size, qey_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if fuse_mask is not None:
            assert memory is not None, 'Fuse mechanism only applied with cross-attention'
        if cache and use_relative_pos:
            assert decode_step is not None, 'Decode Step must provide when use relative position encoding'

        if memory is None:
            # suppose self-attention from queries alone
            h = linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                # incremental decoding: append new keys/values to the cache
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                # encoder keys/values are static: reuse cached projections
                k, v = cache['mk'], cache['mv']
            else:
                k = linear(memory, hidden_size, ln=ln, scope="k_map")
                v = linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        # scaled dot-product attention
        q *= (hidden_size // num_heads) ** (-0.5)

        q_shp = util.shape_list(q)
        k_shp = util.shape_list(k)
        v_shp = util.shape_list(v)

        # during cached decoding the query is a single step at decode_step
        q_len = q_shp[2] if decode_step is None else decode_step + 1
        r_lst = None if decode_step is None else 1

        # q * k => attention weights
        if use_relative_pos:
            r = rpr.get_relative_positions_embeddings(
                q_len, k_shp[2], k_shp[3],
                max_relative_position, name="rpr_keys", last=r_lst)
            logits = rpr.relative_attention_inner(q, k, r, transpose=True)
        else:
            logits = tf.matmul(q, k, transpose_b=True)

        if mem_mask is not None:
            logits += mem_mask

        weights = tf.nn.softmax(logits)

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        if use_relative_pos:
            r = rpr.get_relative_positions_embeddings(
                q_len, k_shp[2], v_shp[3],
                max_relative_position, name="rpr_values", last=r_lst)
            o = rpr.relative_attention_inner(dweights, v, r, transpose=False)
        else:
            o = tf.matmul(dweights, v)

        o = combine_heads(o)

        if fuse_mask is not None:
            # This is for AAN, the important part is sharing v_map
            v_q = linear(query, hidden_size, ln=ln, scope="v_map")

            if cache is not None and 'aan' in cache:
                # decoding: running sum divided by (1-based) timestep
                aan_o = (v_q + cache['aan']) / dtype.tf_to_float(fuse_mask + 1)
            else:
                # Simplified Average Attention Network
                aan_o = tf.matmul(fuse_mask, v_q)

            if cache is not None:
                if 'aan' not in cache:
                    cache['aan'] = v_q
                else:
                    cache['aan'] = v_q + cache['aan']

            # Directly sum both self-attention and cross attention
            o = o + aan_o

        if out_map:
            o = linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results
def decoder(target, state, params):
    # AAN Transformer decoder: self-attention is replaced by average
    # attention (possibly an ensemble of strategies) with an FFN and a
    # gating layer, followed by cross-attention and feed-forward sub-layers.
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size
    initializer = tf.random_normal_initializer(0.0, hidden_size**-0.5)

    # 'decoder' only exists in `state` at inference time
    is_training = ('decoder' not in state)

    if is_training:
        target, mask = util.remove_invalid_seq(target, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size],
                              initializer=initializer)
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target) * (hidden_size**0.5)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift the target right by one for teacher forcing
    if is_training:
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
        inputs = func.add_timing_signal(inputs)
    else:
        # first decoding step feeds zeros (pad token)
        inputs = tf.cond(
            tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
            lambda: tf.zeros_like(inputs),
            lambda: inputs)
        mask = tf.ones_like(mask)
        inputs = func.add_timing_signal(inputs,
                                        time=dtype.tf_to_float(state['time']))

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("decoder"):
        x = inputs
        for layer in range(params.num_decoder_layer):
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1)**-0.5,
                    mode="fan_avg",
                    distribution="uniform")
            else:
                layer_initializer = None
            with tf.variable_scope("layer_{}".format(layer),
                                   initializer=layer_initializer):
                with tf.variable_scope("average_attention"):
                    x_fwds = []
                    for strategy in params.strategies:
                        with tf.variable_scope(strategy):
                            x_fwd = average_attention_strategy(
                                strategy, x, mask, state, layer, params)
                            x_fwds.append(x_fwd)
                    # average over the enabled AAN strategies
                    x_fwd = tf.add_n(x_fwds) / len(x_fwds)

                    # FFN activation
                    if params.use_ffn:
                        y = func.ffn_layer(
                            x_fwd,
                            params.filter_size,
                            hidden_size,
                            dropout=params.relu_dropout,
                        )
                    else:
                        y = x_fwd

                    # Gating layer
                    z = func.linear(tf.concat([x, y], axis=-1),
                                    hidden_size * 2,
                                    scope="z_project")
                    i, f = tf.split(z, 2, axis=-1)
                    y = tf.sigmoid(i) * x + tf.sigmoid(f) * y

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("cross_attention"):
                    y = func.dot_attention(
                        x,
                        state['encodes'],
                        func.attention_bias(state['mask'], "masking"),
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout,
                        cache=None if is_training else
                        state['decoder']['state']['layer_{}'.format(layer)])
                    if not is_training:
                        # mk, mv
                        state['decoder']['state']['layer_{}'.format(layer)]\
                            .update(y['cache'])

                    y = y['output']

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("feed_forward"):
                    y = func.ffn_layer(
                        x,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

    feature = x
    if 'dev_decode' in state:
        # during dev decoding only the last position is scored
        feature = x[:, -1, :]

    # output embedding selection mirrors the sharing flags
    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size],
                                  initializer=initializer)
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(
        target,
        util.shape_list(logits)[-1],
        factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits,
        labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    # masked token-level cross-entropy, averaged per sample then per batch
    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(
        mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32),
                   lambda: loss)

    return loss, logits, state, per_sample_loss
def encoder(source, params):
    # Transformer encoder; also prepares per-decoder-layer zero-valued AAN
    # caches used for incremental decoding.
    mask = dtype.tf_to_float(tf.cast(source, tf.bool))
    hidden_size = params.hidden_size
    initializer = tf.random_normal_initializer(0.0, hidden_size**-0.5)

    source, mask = util.remove_invalid_seq(source, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "src_embedding"
    src_emb = tf.get_variable(embed_name,
                              [params.src_vocab.size(), params.embed_size],
                              initializer=initializer)
    src_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(src_emb, source) * (hidden_size**0.5)
    inputs = tf.nn.bias_add(inputs, src_bias)

    inputs = func.add_timing_signal(inputs)
    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("encoder"):
        x = inputs
        for layer in range(params.num_encoder_layer):
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1)**-0.5,
                    mode="fan_avg",
                    distribution="uniform")
            else:
                layer_initializer = None
            with tf.variable_scope("layer_{}".format(layer),
                                   initializer=layer_initializer):
                with tf.variable_scope("self_attention"):
                    y = func.dot_attention(
                        x,
                        None,
                        func.attention_bias(mask, "masking"),
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout)

                    y = y['output']
                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("feed_forward"):
                    y = func.ffn_layer(
                        x,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

    source_encodes = x
    x_shp = util.shape_list(x)

    return {
        "encodes": source_encodes,
        "decoder_initializer": {
            "layer_{}".format(l): {
                # plan aan: zero running-sum cache per decoder layer
                "aan": dtype.tf_to_float(tf.zeros([x_shp[0], 1, hidden_size])),
            }
            for l in range(params.num_decoder_layer)
        },
        "mask": mask
    }
def decoder(target, state, params):
    # Transformer decoder with L0Drop: encoder outputs are pruned by a
    # learned L0 gate before cross-attention; the L0 regularizer is added
    # to the training loss.
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size
    initializer = tf.random_normal_initializer(0.0, hidden_size**-0.5)

    # 'decoder' only exists in `state` at inference time
    is_training = ('decoder' not in state)

    if is_training:
        target, mask = util.remove_invalid_seq(target, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size],
                              initializer=initializer)
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target) * (hidden_size**0.5)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift the target right by one for teacher forcing
    if is_training:
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
        inputs = func.add_timing_signal(inputs)
    else:
        inputs = tf.cond(
            tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
            lambda: tf.zeros_like(inputs),
            lambda: inputs)
        mask = tf.ones_like(mask)
        inputs = func.add_timing_signal(inputs,
                                        time=dtype.tf_to_float(state['time']))

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    # Applying L0Drop
    # --------
    source_memory = state["encodes"]
    source_mask = state["mask"]

    # source_pruning: log alpha_i = x_i w^T
    source_pruning = func.linear(source_memory, 1, scope="source_pruning")

    if is_training:  # training
        source_memory, l0_mask = l0norm.var_train(
            (source_memory, source_pruning))
        # expected L0 (gate-open probability), averaged over valid tokens
        l0_norm_loss = tf.squeeze(l0norm.l0_norm(source_pruning), -1)
        l0_norm_loss = tf.reduce_sum(l0_norm_loss * source_mask, -1) / \
            tf.reduce_sum(source_mask, -1)
        l0_norm_loss = tf.reduce_mean(l0_norm_loss)
        # ramped-up regularization schedule
        l0_norm_loss = l0norm.l0_regularization_loss(
            l0_norm_loss,
            reg_scalar=params.l0_norm_reg_scalar,
            start_reg_ramp_up=params.l0_norm_start_reg_ramp_up,
            end_reg_ramp_up=params.l0_norm_end_reg_ramp_up,
            warm_up=params.l0_norm_warm_up,
        )

        # force the model to only attend to unmasked position
        source_mask = dtype.tf_to_float(
            tf.cast(tf.squeeze(l0_mask, -1), tf.bool)) * source_mask
    else:  # evaluation
        source_memory, l0_mask = l0norm.var_eval(
            (source_memory, source_pruning))
        l0_norm_loss = 0.0

        # physically drop gated-out encoder states (shorter memory)
        source_memory, source_mask, count_mask = extract_encodes(
            source_memory, source_mask, l0_mask)
        count_mask = tf.expand_dims(tf.expand_dims(count_mask, 1), 1)
    # --------

    with tf.variable_scope("decoder"):
        x = inputs
        for layer in range(params.num_decoder_layer):
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1)**-0.5,
                    mode="fan_avg",
                    distribution="uniform")
            else:
                layer_initializer = None
            with tf.variable_scope("layer_{}".format(layer),
                                   initializer=layer_initializer):
                with tf.variable_scope("self_attention"):
                    y = func.dot_attention(
                        x,
                        None,
                        func.attention_bias(tf.shape(mask)[1], "causal"),
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout,
                        cache=None if is_training else
                        state['decoder']['state']['layer_{}'.format(layer)])
                    if not is_training:
                        # k, v
                        state['decoder']['state']['layer_{}'.format(layer)] \
                            .update(y['cache'])

                    y = y['output']
                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("cross_attention"):
                    if is_training:
                        y = func.dot_attention(
                            x,
                            source_memory,
                            func.attention_bias(source_mask, "masking"),
                            hidden_size,
                            num_heads=params.num_heads,
                            dropout=params.attention_dropout,
                        )
                    else:
                        # module-local dot_attention variant that takes
                        # count_mask to account for dropped source tokens
                        y = dot_attention(
                            x,
                            source_memory,
                            func.attention_bias(source_mask, "masking"),
                            hidden_size,
                            count_mask=count_mask,
                            num_heads=params.num_heads,
                            dropout=params.attention_dropout,
                            cache=state['decoder']['state'][
                                'layer_{}'.format(layer)])
                        # mk, mv
                        state['decoder']['state']['layer_{}'.format(layer)] \
                            .update(y['cache'])

                    y = y['output']
                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("feed_forward"):
                    y = func.ffn_layer(
                        x,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

    feature = x
    if 'dev_decode' in state:
        feature = x[:, -1, :]

    # output embedding selection mirrors the sharing flags
    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size],
                                  initializer=initializer)
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(
        target,
        util.shape_list(logits)[-1],
        factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits,
        labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    # masked token-level cross-entropy, averaged per sample then per batch
    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(
        mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # add the L0 sparsity regularizer (0.0 at evaluation)
    loss = loss + l0_norm_loss

    # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, tf.float32),
                   lambda: loss)

    return loss, logits, state, per_sample_loss
def deep_att_dec_rnn(cell_name, x, memory, d, init_state=None,
                     mask=None, mem_mask=None, ln=False, sm=True,
                     depth=1, num_heads=1):
    """Self implemented conditional-RNN procedure, supporting mask trick"""
    # cell_name: gru, lstm or atr
    # x: input sequence embedding matrix, [batch, seq_len, dim]
    # memory: the conditional part
    # d: hidden dimension for rnn
    # mask: mask matrix, [batch, seq_len]
    # mem_mask: memory mask matrix, [batch, mem_seq_len]
    # ln: whether use layer normalization
    # init_state: the initial hidden states, for cache purpose
    # sm: whether apply swap memory during rnn scan
    # depth: depth for the decoder in deep attention
    # num_heads: number of attention heads, multi-head attention
    # dp: variational dropout

    in_shape = util.shape_list(x)
    batch_size, time_steps = in_shape[:2]
    mem_shape = util.shape_list(memory)

    # one lower cell plus `depth` higher cells, one per attention level
    cell_lower = rnn.get_cell(cell_name, d, ln=ln,
                              scope="{}_lower".format(cell_name))
    cells_higher = []
    for layer in range(depth):
        cell_higher = rnn.get_cell(cell_name, d, ln=ln,
                                   scope="{}_higher_{}".format(cell_name,
                                                               layer))
        cells_higher.append(cell_higher)

    if init_state is None:
        init_state = cell_lower.get_init_state(shape=[batch_size])
    if mask is None:
        mask = dtype.tf_to_float(tf.ones([batch_size, time_steps]))
    if mem_mask is None:
        mem_mask = dtype.tf_to_float(tf.ones([batch_size, mem_shape[1]]))

    # prepare projected encodes and inputs (time-major for tf.scan)
    cache_inputs = cell_lower.fetch_states(x)
    cache_inputs = [tf.transpose(v, [1, 0, 2]) for v in list(cache_inputs)]
    proj_memories = func.linear(memory, mem_shape[-1], bias=False,
                                ln=ln, scope="context_att")

    mask_ta = tf.transpose(tf.expand_dims(mask, -1), [1, 0, 2])
    init_context = dtype.tf_to_float(
        tf.zeros([batch_size, depth, mem_shape[-1]]))
    init_weight = dtype.tf_to_float(
        tf.zeros([batch_size, depth, num_heads, mem_shape[1]]))
    mask_pos = len(cache_inputs)

    def _step_fn(prev, x):
        # one scan step: lower cell, then `depth` attention+recurrence levels
        t, h_, c_, a_ = prev

        m, v = x[mask_pos], x[:mask_pos]

        # the first decoder rnn subcell, composing previous hidden state
        # with the current word embedding
        s_ = cell_lower(h_, v)
        # mask trick: keep the previous state on padded positions
        s_ = m * s_ + (1. - m) * h_

        atts, att_ctxs = [], []
        for layer in range(depth):
            # perform attention
            prev_cell = cell_lower if layer == 0 else cells_higher[layer - 1]
            vle = func.additive_attention(
                prev_cell.get_hidden(s_), memory, mem_mask,
                mem_shape[-1], ln=ln, num_heads=num_heads,
                proj_memory=proj_memories,
                scope="deep_attention_{}".format(layer))
            a, c = vle['weights'], vle['output']
            atts.append(tf.expand_dims(a, 1))
            att_ctxs.append(tf.expand_dims(c, 1))

            # perform next-level recurrence
            c_c = cells_higher[layer].fetch_states(c)
            ss_ = cells_higher[layer](s_, c_c)
            s_ = m * ss_ + (1. - m) * s_

        h = s_
        a = tf.concat(atts, axis=1)
        c = tf.concat(att_ctxs, axis=1)

        return t + 1, h, c, a

    time = tf.constant(0, dtype=tf.int32, name="time")
    step_states = (time, init_state, init_context, init_weight)
    step_vars = cache_inputs + [mask_ta]

    outputs = tf.scan(_step_fn, step_vars,
                      initializer=step_states,
                      parallel_iterations=32,
                      swap_memory=sm)

    output_ta = outputs[1]
    context_ta = outputs[2]
    attention_ta = outputs[3]

    # back to batch-major
    outputs = tf.transpose(output_ta, [1, 0, 2])
    output_states = outputs[:, -1]
    # batch x target length x depth x mem-dimension
    contexts = tf.transpose(context_ta, [1, 0, 2, 3])
    # batch x num_heads x depth x target length x source length
    attentions = tf.transpose(attention_ta, [1, 3, 2, 0, 4])

    return (outputs, output_states), \
           (cells_higher[-1].get_hidden(outputs),
            cells_higher[-1].get_hidden(output_states)), \
        contexts, attentions
def decoder(target, state, params):
    """Layered RNN decoder with attention; returns the training loss.

    Layer 0 (and every layer when `params.use_deep_att` is set) attends
    over the source encodings via `rnn.cond_rnn`; deeper layers reuse
    the layer-0 attention contexts `c`.

    Args:
        target: target token ids, [batch, seq_len]
        state: model state dict; holds "encodes", "mask",
            "decoder_initializer" and, during step-wise decoding, the
            per-layer cache under state['decoder']['state']
        params: hyper-parameter container

    Returns:
        (loss, logits, state, per_sample_loss)
    """
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size

    # the 'decoder' cache entry only exists during step-wise decoding
    is_training = ('decoder' not in state)

    if is_training:
        target, mask = util.remove_invalid_seq(target, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size])
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift the target one step right so position t only sees tokens < t
    if is_training:
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
    else:
        # at the first decoding step (all-<pad> input) feed zeros instead
        inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
                         lambda: tf.zeros_like(inputs),
                         lambda: inputs)
        mask = tf.ones_like(mask)

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("decoder"):
        x = inputs
        for layer in range(params.num_decoder_layer):
            with tf.variable_scope("layer_{}".format(layer)):
                init_state = state["decoder_initializer"]["layer_{}".format(
                    layer)]
                if not is_training:
                    # decoding resumes from the cached per-layer state
                    init_state = state["decoder"]["state"]["layer_{}".format(
                        layer)]
                if layer == 0 or params.use_deep_att:
                    # conditional RNN attending over the source encodings
                    returns = rnn.cond_rnn(params.cell, x, state["encodes"],
                                           hidden_size,
                                           init_state=init_state, mask=mask,
                                           num_heads=params.num_heads,
                                           mem_mask=state["mask"],
                                           ln=params.layer_norm,
                                           sm=params.swap_memory,
                                           one2one=False)
                    (_, hidden_state), (outputs, _), contexts, attentions = returns
                    c = contexts
                else:
                    # deeper layers consume the layer-0 contexts `c`
                    if params.caencoder:
                        returns = rnn.cond_rnn(params.cell, x, c, hidden_size,
                                               init_state=init_state,
                                               mask=mask,
                                               mem_mask=mask,
                                               ln=params.layer_norm,
                                               sm=params.swap_memory,
                                               num_heads=params.num_heads,
                                               one2one=True)
                        (_, hidden_state), (outputs, _), contexts, attentions = returns
                    else:
                        outputs = rnn.rnn(params.cell, tf.concat([x, c], -1),
                                          hidden_size, mask=mask,
                                          init_state=init_state,
                                          ln=params.layer_norm,
                                          sm=params.swap_memory)
                        outputs, hidden_state = outputs[1]
                if not is_training:
                    # refresh the decoding cache with this layer's state
                    state['decoder']['state']['layer_{}'.format(
                        layer)] = hidden_state

                y = func.linear(outputs, params.embed_size, ln=False,
                                scope="ff")

                # short cut via residual connection
                if x.get_shape()[-1].value == y.get_shape()[-1].value:
                    x = func.residual_fn(x, y, dropout=params.dropout)
                else:
                    x = y
                if params.layer_norm:
                    x = func.layer_norm(x, scope="ln")

    if params.dl4mt_redict:
        # dl4mt-style readout: combine decoder output with the contexts
        feature = func.linear(tf.concat([x, c], -1), params.embed_size,
                              ln=params.layer_norm, scope="ff")
        feature = tf.nn.tanh(feature)

        feature = util.valid_apply_dropout(feature, params.dropout)
    else:
        feature = x

    if 'dev_decode' in state:
        # NOTE(review): this slices `x`, not `feature`, so the
        # dl4mt readout above is bypassed for dev decoding — confirm
        # this asymmetry is intended
        feature = x[:, -1, :]

    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size])
    feature = tf.reshape(feature, [-1, params.embed_size])
    # matmul(..., False, True): transpose the embedding, not the feature
    logits = tf.matmul(feature, softmax_emb, False, True)

    # compute the loss in float32 regardless of the model dtype
    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(
        target,
        util.shape_list(logits)[-1],
        factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits,
        labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(
        mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32),
                   lambda: loss)

    return loss, logits, state, per_sample_loss
def encoder(source, params):
    """Alternating-direction deep RNN encoder (deep attention paper).

    Layer 0 is a plain RNN over the embeddings; each deeper layer is a
    conditional RNN over the embeddings and the previous layer's hidden
    sequence, run in reversed direction on odd layers so successive
    layers see future context.

    Args:
        source: source token ids, [batch, seq_len]
        params: hyper-parameter container

    Returns:
        dict with "encodes" (final hidden sequence),
        "decoder_initializer" (decoder init state derived from the last
        layer's final hidden state) and "mask" (cleaned source mask).
    """
    mask = dtype.tf_to_float(tf.cast(source, tf.bool))
    hidden_size = params.hidden_size

    source, mask = util.remove_invalid_seq(source, mask)

    # extract source word embedding and apply dropout
    embed_name = "embedding" if params.shared_source_target_embedding \
        else "src_embedding"
    src_emb = tf.get_variable(embed_name,
                              [params.src_vocab.size(), params.embed_size])
    src_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(src_emb, source)
    inputs = tf.nn.bias_add(inputs, src_bias)

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    # the encoder module used in the deep attention paper
    with tf.variable_scope("encoder"):
        # x: embedding input, h: the hidden state
        x = inputs
        h = 0
        z = 0

        for layer in range(params.num_encoder_layer + 1):
            with tf.variable_scope("layer_{}".format(layer)):
                if layer == 0:
                    # for the first layer, we perform a normal rnn layer
                    # to collect context information
                    outputs = rnn.rnn(params.cell, x, hidden_size,
                                      mask=mask, ln=params.layer_norm,
                                      sm=params.swap_memory)
                    h = outputs[1][0]
                else:
                    # for deeper encoder layers, we incorporate both
                    # embedding input and previous inversed hidden state
                    # sequence as input. the embedding informs current
                    # input while hidden state tells future context
                    is_reverse = (layer % 2 == 1)
                    outputs = rnn.cond_rnn(
                        params.cell,
                        tf.reverse(x, [1]) if is_reverse else x,
                        tf.reverse(h, [1]) if is_reverse else h,
                        hidden_size,
                        mask=tf.reverse(mask, [1]) if is_reverse else mask,
                        ln=params.layer_norm,
                        sm=params.swap_memory,
                        num_heads=params.num_heads,
                        one2one=True)
                    h = outputs[1][0]
                    # undo the reversal so h is always left-to-right
                    h = tf.reverse(h, [1]) if is_reverse else h

                # the final hidden state used for decoder state initialization
                z = outputs[1][1]

    with tf.variable_scope("decoder_initializer"):
        decoder_cell = rnn.get_cell(params.cell, hidden_size,
                                    ln=params.layer_norm)

    return {
        "encodes": h,
        "decoder_initializer": {
            'layer': decoder_cell.get_init_state(x=z, scope="dec_init_state")
        },
        "mask": mask
    }
def decoder(target, state, params):
    """Deep-attention decoder built on deep_att_dec_rnn; returns the loss.

    Args:
        target: target token ids, [batch, seq_len]
        state: model state dict from the encoder ("encodes", "mask",
            "decoder_initializer"); during step-wise decoding it also
            holds the cache under state['decoder']['state']['layer']
        params: hyper-parameter container

    Returns:
        (loss, logits, state, per_sample_loss)
    """
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size

    # the 'decoder' cache entry only exists during step-wise decoding
    is_training = ('decoder' not in state)

    # handling target-side word embedding, including shift-padding for training
    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size])
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift
    if is_training:
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
    else:
        # at the first decoding step (all-<pad> input) feed zeros instead
        inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
                         lambda: tf.zeros_like(inputs),
                         lambda: inputs)
        mask = tf.ones_like(mask)

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("decoder"):
        x = inputs

        init_state = state["decoder_initializer"]["layer"]
        if not is_training:
            # decoding resumes from the cached state
            init_state = state["decoder"]["state"]["layer"]
        returns = deep_att_dec_rnn(params.cell, x, state["encodes"],
                                   hidden_size,
                                   init_state=init_state, mask=mask,
                                   num_heads=params.num_heads,
                                   mem_mask=state["mask"],
                                   ln=params.layer_norm,
                                   sm=params.swap_memory,
                                   depth=params.num_decoder_layer)
        (_, hidden_state), (outputs, _), contexts, attentions = returns
        if not is_training:
            state['decoder']['state']['layer'] = hidden_state

        x = outputs

        # flatten per-depth contexts into batch x tgt_len x (depth * mem_dim)
        cshp = util.shape_list(contexts)
        c = tf.reshape(contexts, [cshp[0], cshp[1], cshp[2] * cshp[3]])

    # readout mixes decoder output, attention contexts and the embeddings
    feature = func.linear(tf.concat([x, c, inputs], -1), params.embed_size,
                          ln=params.layer_norm, scope="ff")
    feature = tf.nn.tanh(feature)

    feature = util.valid_apply_dropout(feature, params.dropout)

    if 'dev_decode' in state:
        # for dev decoding only the last position is scored
        feature = feature[:, -1, :]

    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size])
    feature = tf.reshape(feature, [-1, params.embed_size])
    # matmul(..., False, True): transpose the embedding, not the feature
    logits = tf.matmul(feature, softmax_emb, False, True)

    # compute the loss in float32 regardless of the model dtype
    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(
        target,
        util.shape_list(logits)[-1],
        factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits,
        labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(
        mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32),
                   lambda: loss)

    return loss, logits, state, per_sample_loss