def symbols_to_logits_fn(input_symbols, i, state):
    # [batch_size, decoded_ids] to [batch_size, vocab_size]
    leng = shape_list(input_symbols)[1]
    pos_embed = encoder.vocab_size + tf.range(leng)
    pos_embed = tf.tile([pos_embed], [shape_list(input_symbols)[0], 1])
    inp = tf.pad(
        tf.stack([input_symbols, pos_embed], -1),
        [[0, 0], [0, 1], [0, 0]],
    )
    target_feat_state = config.base_model.get_featurizer(
        inp,
        encoder=encoder,
        config=config,
        train=train,
        encoder_state=(
            {**state["featurizer_state"], "embed_weights": embed_weights}
            if state
            else None
        ),
    )
    output_state = language_model(
        X=inp,
        M=tf.sequence_mask(
            target_feat_state["eos_idx"] + 1, maxlen=leng + 1, dtype=tf.float32
        ),  # +1 because we want to predict the clf token as an eos token
        embed_weights=embed_weights[:encoder.vocab_size, :],
        config=config,
        reuse=reuse,
        train=train,
        hidden=target_feat_state["sequence_features"],  # deal with state
    )
    if state is None:
        return output_state["logits"][:, i, :]
    else:
        return output_state["logits"][:, i, :], state

def dynamic_convolution(inp, inp2=None, n_heads=8, kernel_size=4, padding="causal"):
    batch, seq, n_channels = shape_list(inp)
    kernel_size_inc_ra = kernel_size
    if inp2 is not None:
        kernel_size_inc_ra += shape_list(inp2)[2]
    weights_linear = tf.reshape(
        linear(
            inp,
            n_heads * kernel_size_inc_ra,
            "kernel_machine",
        ),
        [batch, seq, kernel_size_inc_ra, n_heads],  # batch, time, kernel_size, heads
    )
    weights_linear = tf.nn.softmax(weights_linear, 2)

    if inp2 is not None:
        weights_ra = weights_linear[:, :, kernel_size:]
        weights_linear = weights_linear[:, :, :kernel_size]
        ra_dynamic = dynamic_conv_on_ra_out(inp2, weights_ra)
    else:
        ra_dynamic = 0.0

    if indico_ops.BUILT and tf.test.is_gpu_available():
        # There is no CPU version of this op, so we need to check for a GPU.
        dynamic_conv_fn = indico_ops.dynamic_convolution_op
    else:
        dynamic_conv_fn = dynamic_conv_cpu

    return dynamic_conv_fn(inp, weights_linear, padding=padding) + ra_dynamic

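# Shape walk-through for dynamic_convolution (illustrative numbers, not from the
# source): with batch=2, seq=10, n_channels=768, n_heads=8, kernel_size=4 and
# inp2=None, `linear` produces a [2, 10, 32] tensor that is reshaped to
# [2, 10, 4, 8]; the softmax over axis 2 then normalises each head's 4 filter
# taps independently at every time step before the depthwise dynamic
# convolution mixes the inputs.
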
def language_model(*, X, M, embed_weights, hidden, config, reuse=None):
    """
    A language model output and loss for the language modelling objective described in the original finetune paper.
    This language model uses weights that are tied to the input embedding.

    :param X: The raw token ids fed to the featurizer.
    :param M: A loss mask, with 1's where losses should be counted and 0's elsewhere.
    :param embed_weights: The word embedding matrix, normally the one returned by the featurizer.
    :param hidden: Output of the featurizer.
    :param config: A config object.
    :param reuse: A flag passed through to the tf.variable_scope context manager.
    :return: A dict containing:
        logits: The un-normalised log-probabilities over each word in the vocabulary.
        losses: The masked language modelling loss.
        perplexity: The exponentiated language modelling loss, averaged over the loss mask.
    """
    X = merge_leading_dims(X, 3)
    M = merge_leading_dims(M, 2)
    hidden = merge_leading_dims(hidden, 3)

    batch, seq, _ = shape_list(X)

    with tf.variable_scope('model/language-model', reuse=reuse):
        # language model ignores last hidden state because we don't have a target
        sliced_hidden = hidden[:, :-1]
        lm_h = tf.reshape(
            sliced_hidden, [-1, config.n_embed]
        )  # [batch, seq_len, embed] --> [batch * seq_len, embed]
        lm_logits = tf.matmul(lm_h, embed_weights, transpose_b=True)  # tied weights
        lm_logits = tf.reshape(
            lm_logits, [batch, seq - 1, tf.shape(embed_weights)[0]])

        lm_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=lm_logits, labels=X[:, 1:, 0])

        perplexity = tf.reduce_sum(tf.exp(lm_losses) * M[:, 1:], 1) / tf.reduce_sum(M[:, 1:], 1)

        lm_losses = tf.reshape(
            lm_losses, [shape_list(X)[0], shape_list(X)[1] - 1])

        # tf.maximum op prevents divide by zero error when mask is all 0s
        lm_losses = tf.reduce_sum(lm_losses * M[:, 1:], 1) / tf.maximum(
            tf.reduce_sum(M[:, 1:], 1), 1)

        lm_logits_shape = shape_list(lm_logits)
        sliced_hidden_shape = shape_list(sliced_hidden)

        return {
            'logits': tf.reshape(
                lm_logits, shape=sliced_hidden_shape[:-1] + [lm_logits_shape[-1]]),
            'losses': lm_losses,
            'perplexity': perplexity,
        }

def dynamic_conv_on_ra_out(ra_out, weights):
    """
    ra_out:  Batch, Time, ra_size, Channels
    weights: Batch, Time, FilterWidth, Heads
    """
    batch, seq, ra_depth, n_channels = shape_list(ra_out)
    _, _, kernel_width, n_heads = shape_list(weights)
    assert n_channels % n_heads == 0
    assert ra_depth == kernel_width
    h = n_channels // n_heads
    unfolded = tf.reshape(ra_out, [batch, seq, ra_depth, n_heads, h])
    weights_expanded = tf.expand_dims(weights, 4)
    return tf.reshape(
        tf.reduce_sum(weights_expanded * unfolded, 2), [batch, seq, n_channels]
    )

def linear(inp, output_dim, layer_name):
    with tf.variable_scope(layer_name):
        nx = shape_list(inp)[-1]
        if output_dim is None:
            output_dim = nx
        W = tf.get_variable(
            name="W",
            shape=[nx, output_dim],
            initializer=tf.initializers.glorot_normal(),
        )
        if inp.dtype == tf.float16:
            W = tf.cast(W, tf.float16)
        return tf.reshape(
            tf.matmul(tf.reshape(inp, [-1, nx]), tf.reshape(W, [nx, output_dim])),
            shape_list(inp)[:-1] + [output_dim],
        )

def mask_attn_weights(w):
    # w has shape [batch, heads, dst_sequence, src_sequence],
    # where information flows from src to dst.
    _, _, nd, ns = shape_list(w)
    b = attention_mask(nd, ns, dtype=w.dtype)
    b = tf.reshape(b, [1, 1, nd, ns])
    w = w * b - tf.cast(1e10, w.dtype) * (1 - b)
    return w

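# `attention_mask` is referenced above but not defined in this excerpt. A
# minimal sketch, assuming the usual causal (lower-triangular) mask counted
# from the lower-right corner so it also behaves sensibly when nd != ns:
def attention_mask(nd, ns, *, dtype):
    """Returns an [nd, ns] matrix with 1s where dst position i may attend to src position j."""
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    return tf.cast(m, dtype)
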
def block(
    x,
    n_head,
    act_fn,
    adptr_size,
    resid_pdrop,
    attn_pdrop,
    scope,
    train=False,
    scale=False,
    explain=False,
):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        a = attn(
            x,
            "attn",
            nx,
            n_head,
            resid_pdrop,
            attn_pdrop,
            train=train,
            scale=scale,
            explain=explain,
        )
        if adptr_size is not None:
            with tf.variable_scope("attn_adapter"):
                a = adapter(a, adptr_size, nx, train)
        n = norm(x + a, "ln_1")
        m = mlp(n, "mlp", nx * 4, act_fn, resid_pdrop, train=train)
        if adptr_size is not None:
            with tf.variable_scope("dense_adapter"):
                m = adapter(m, adptr_size, nx, train)
        h = norm(n + m, "ln_2")
        return h

def normal_1d_conv_block(X, kernel_width, layer_name, use_fp16, dilation=1, output_dim=None, causal=True):
    # layer input shape: [batch, seq, embed_dim] or [batch, channels, seq, embed_dim]
    with tf.variable_scope(layer_name):
        # Pad (kernel_width - 1) * dilation on the left (word-wise) to stop future viewing.
        left_pad = (kernel_width - 1) * dilation
        if causal:
            paddings = [[0, 0], [left_pad, 0], [0, 0]]
        else:
            paddings = [[0, 0], [left_pad // 2, left_pad - (left_pad // 2)], [0, 0]]

        if kernel_width > 1:
            padded_input = tf.pad(X, paddings, "CONSTANT")
        else:
            padded_input = X

        nx = shape_list(X)[-1]
        if output_dim is None:
            output_dim = nx
        W = tf.get_variable(
            name="W",
            shape=[kernel_width, nx, output_dim],
            initializer=tf.initializers.glorot_normal(),
        )
        b = tf.get_variable(
            name="B",
            shape=[output_dim],
            initializer=tf.initializers.constant(0.0),
        )

        if use_fp16:
            W = tf.cast(W, tf.float16)
            b = tf.cast(b, tf.float16)

        conv = causal_conv(padded_input, W, dilation)
        conv = tf.nn.bias_add(conv, b)
        return conv

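# `causal_conv` is not defined in this excerpt. A minimal sketch, assuming it is
# a plain dilated 1-D convolution with VALID padding (normal_1d_conv_block has
# already applied the causal left-padding before calling it):
def causal_conv(value, filters, dilation_rate):
    # value:   [batch, padded_seq, in_channels]
    # filters: [kernel_width, in_channels, out_channels]
    return tf.nn.convolution(
        value, filters, padding="VALID", dilation_rate=[dilation_rate]
    )
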
def dynamic_conv_cpu(inp, weights, padding="causal"):
    """
    inp    : Batch, Time, Channels
    weights: Batch, Time, FilterWidth, Heads
    padding: causal or same
    """
    batch, seq, n_channels = shape_list(inp)
    _, _, kernel_width, n_heads = shape_list(weights)
    assert n_channels % n_heads == 0
    h = n_channels // n_heads
    unfolded = unfold(inp, kernel_width, padding=padding)
    unfolded = tf.reshape(unfolded, [batch, seq, kernel_width, n_heads, h])
    weights_expanded = tf.expand_dims(weights, 4)
    return tf.reshape(
        tf.reduce_sum(weights_expanded * unfolded, 2), shape_list(inp)
    )

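# `unfold` is also not defined here. A minimal sketch, assuming it extracts
# sliding windows of width `kernel_width` along the time axis; with causal
# padding, windows[:, t, k, :] is the (zero-padded) input at position
# t - (kernel_width - 1) + k:
def unfold(inp, kernel_width, padding="causal"):
    batch, seq, channels = shape_list(inp)
    if padding == "causal":
        pad_l, pad_r = kernel_width - 1, 0
    else:  # "same"
        pad_l = (kernel_width - 1) // 2
        pad_r = kernel_width - 1 - pad_l
    padded = tf.pad(inp, [[0, 0], [pad_l, pad_r], [0, 0]])
    windows = tf.stack(
        [padded[:, k:k + seq, :] for k in range(kernel_width)], axis=2
    )
    return windows  # [batch, seq, kernel_width, channels]
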
def block(x, n_head, act_fn, resid_pdrop, attn_pdrop, scope, train=False, scale=False):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        a = attn(x, 'attn', nx, n_head, resid_pdrop, attn_pdrop, train=train, scale=scale)
        n = norm(x + a, 'ln_1')
        m = mlp(n, 'mlp', nx * 4, act_fn, resid_pdrop, train=train)
        h = norm(n + m, 'ln_2')
        return h

def mlp(x, scope, n_state, act_fn, resid_pdrop, train=False):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        act = act_fns[act_fn]
        h = act(conv1d(x, "c_fc", n_state, 1, train=train))
        h2 = conv1d(h, "c_proj", nx, 1, train=train)
        h2 = dropout(h2, resid_pdrop, train)
        return h2

def mask_pad(w, lengths):
    batch = shape_list(lengths)[0]
    maxlen = tf.cast(tf.reduce_max(lengths), tf.int32)
    seq_mask = tf.reshape(
        tf.sequence_mask(lengths, maxlen=maxlen), [batch, 1, 1, maxlen]
    )
    b = tf.cast(seq_mask, tf.float32)
    w = w * b + -1e9 * (1 - b)
    return w

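# Example: lengths = [3, 5] with maxlen = 5 gives the boolean mask
# [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], reshaped to [batch, 1, 1, 5] so that it
# broadcasts over every head and every query position; padded key positions
# receive -1e9 before the softmax and therefore get ~0 attention weight.
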
def language_model(*, X, M, embed_weights, hidden, config, reuse=None, train=False):
    """
    A language model output and loss for the language modelling objective described in the original finetune paper.
    This language model uses weights that are tied to the input embedding.

    :param X: The raw token ids fed to the featurizer.
    :param M: A loss mask, with 1's where losses should be counted and 0's elsewhere.
    :param embed_weights: The word embedding matrix, normally the one returned by the featurizer.
    :param hidden: Output of the featurizer.
    :param config: A config object.
    :param reuse: A flag passed through to the tf.variable_scope context manager.
    :return: A dict containing:
        logits: The un-normalised log-probabilities over each word in the vocabulary.
        losses: The masked language modelling loss.
        perplexity: The exponentiated language modelling loss, averaged over the loss mask.
    """
    X = merge_leading_dims(X, 3)
    M = merge_leading_dims(M, 2)
    hidden = merge_leading_dims(hidden, 3)

    batch, seq, _ = shape_list(X)
    vocab_size, hidden_dim = shape_list(embed_weights)

    with tf.variable_scope('model/language-model', reuse=reuse):
        # The language model ignores the last hidden state because we don't have a target for it.
        lm_h = tf.reshape(hidden, [-1, config.n_embed])  # [batch, seq_len, embed] --> [batch * seq_len, embed]
        lm_logits = tf.matmul(lm_h, embed_weights, transpose_b=True)  # tied weights
        lm_logits = tf.cast(lm_logits, tf.float32)

        hidden_shape = tf.shape(hidden)
        logits = tf.reshape(
            lm_logits, shape=tf.concat([hidden_shape[:-1], [vocab_size]], axis=0)
        )
        lm_logits_offset = tf.reshape(logits[:, :-1], [-1, vocab_size])

        lm_losses = tf.losses.sparse_softmax_cross_entropy(
            logits=lm_logits_offset,
            labels=tf.reshape(X[:, 1:, 0], [-1]),
            weights=tf.reshape(M[:, 1:], [-1]),
        )

        perplexity = tf.reduce_sum(tf.exp(lm_losses) * M[:, 1:], 1) / tf.reduce_sum(M[:, 1:], 1)

        return {
            "logits": logits,
            "losses": lm_losses,
            "perplexity": perplexity,
        }

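# Hedged usage sketch (the variable names below are illustrative assumptions,
# not the library's public API): the head consumes the featurizer's sequence
# features and tied embedding matrix; X carries token ids in channel 0 and
# position ids in channel 1, matching the X[:, 1:, 0] indexing above.
#
#   lm_state = language_model(
#       X=X,                                                # [batch, seq, 2]
#       M=tf.sequence_mask(lengths, maxlen=seq, dtype=tf.float32),
#       embed_weights=featurizer_state["embed_weights"],
#       hidden=featurizer_state["sequence_features"],
#       config=config,
#   )
#   # lm_state["losses"] is a scalar here, suitable for optimizer.minimize(...)
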
def norm(x, scope, axis=[-1], e=1e-5):
    with tf.variable_scope(scope):
        n_state = shape_list(x)[-1]
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
        x = (x - u) * tf.rsqrt(s + e)
        x = x * g + b
        return x

def add_explain_tokens(X, max_length, pool_idx):
    flat_x = tf.reshape(X[:, :, :1], [-1, 1])
    flat_pos = tf.minimum(X[:, :, 1:] + 1, max_length - 1)  # +1 to offset for the start token
    clf_tok = tf.gather(
        flat_x,
        tf.range(shape_list(X)[0], dtype=tf.int32) * max_length + pool_idx,
    )
    clf_tok_x_seq = tf.tile(tf.expand_dims(clf_tok, 1), [1, max_length, 1])
    clf_tok_x_seq_w_pos = tf.concat((clf_tok_x_seq, flat_pos), -1)
    return tf.concat((X, clf_tok_x_seq_w_pos), 1)

def attn_weights(q, k, v, attn_pdrop, train=False, scale=False, mask=True):
    w = tf.matmul(q, k)
    if scale:
        n_state = shape_list(v)[-1]
        w = w * tf.rsqrt(tf.cast(n_state, tf.float32))
    if mask:
        w = mask_attn_weights(w)
    w = tf.nn.softmax(w)
    return w

def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    with tf.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.get_variable(
            'w', [1, nx, nf],
            initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
        c = tf.reshape(
            tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
            start + [nf])
        return c

def conv1d(x, scope, nf, rf, w_init=tf.random_normal_initializer(stddev=0.02),
           b_init=tf.constant_initializer(0), pad='VALID', train=False):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        w = tf.get_variable("w", [rf, nx, nf], initializer=w_init)
        b = tf.get_variable("b", [nf], initializer=b_init)
        if rf == 1:  # faster 1x1 conv
            c = tf.reshape(
                tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
                shape_list(x)[:-1] + [nf])
        else:  # was used to train LM
            c = tf.nn.conv1d(x, w, stride=1, padding=pad) + b
        return c

def dynamic_convolution_op(inp, weights, padding="causal"):
    """
    inp    : Batch, Time, Channels
    weights: Batch, Time, FilterWidth, Heads
    padding: causal or same
    """
    batch, seq, n_channels = shape_list(inp)
    _, _, kernel_width, n_heads = shape_list(weights)
    assert n_channels % n_heads == 0

    if padding.lower() == "causal":
        padding_l = kernel_width - 1
    elif padding.lower() == "same":
        padding_l = kernel_width // 2

    inp_formatted = tf.transpose(inp, [0, 2, 1])
    weights_formatted = tf.transpose(weights, [0, 3, 2, 1])
    return tf.transpose(
        kernels_module.dynamic_convolution(inp_formatted, weights_formatted, padding_l),
        [0, 2, 1],
    )

def _merge_beam_dim(tensor):
    """Reshapes first two dimensions in to single dimension.

    Args:
        tensor: Tensor to reshape of shape [A, B, ...]

    Returns:
        Reshaped tensor of shape [A*B, ...]
    """
    shape = shape_list(tensor)
    shape[0] *= shape[1]  # batch -> batch * beam_size
    shape.pop(1)  # Remove beam dim
    return tf.reshape(tensor, shape)

def masked_language_model(*, X, M, mlm_weights, mlm_positions, mlm_ids, embed_weights,
                          hidden, config, reuse=None, train=False):
    X = merge_leading_dims(X, 3)
    M = merge_leading_dims(M, 2)
    hidden = merge_leading_dims(hidden, 3)

    batch, seq, _ = shape_list(X)

    with tf.variable_scope('model/masked-language-model'):
        gathered_hidden = gather_indexes(hidden, mlm_positions)
        final_proj = tf.layers.dense(
            gathered_hidden,
            units=config.n_embed,
            activation=act_fns[config.act_fn],
            kernel_initializer=tf.random_normal_initializer(stddev=config.weight_stddev),
            name='dense',
        )
        normed_proj = norm(final_proj, 'LayerNorm')
        n_vocab = shape_list(embed_weights)[0]
        output_bias = tf.get_variable(
            "output_bias", shape=[n_vocab], initializer=tf.zeros_initializer()
        )

        logits = tf.matmul(normed_proj, embed_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        mlm_ids = tf.reshape(mlm_ids, [-1])
        mlm_weights = tf.reshape(mlm_weights, [-1])

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(mlm_ids, depth=n_vocab, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])

        numerator = tf.reduce_sum(mlm_weights * per_example_loss)
        denominator = tf.reduce_sum(mlm_weights) + 1e-5
        mlm_loss = numerator / denominator

        return {
            "logits": logits,
            "losses": mlm_loss,
        }

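# Illustrative note on the MLM inputs above (the concrete shapes are assumptions
# for the example): `mlm_positions` holds the token indices that were masked
# out, `mlm_ids` the original vocabulary ids at those positions, and
# `mlm_weights` is 1.0 for genuine masked slots and 0.0 for padding slots, so
# `denominator` only counts real predictions. E.g. for one example with
# positions [4, 9, 0], ids [1023, 57, 0] and weights [1.0, 1.0, 0.0], only the
# first two slots contribute to the loss.
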
def cumulative_state_net(X, name, use_fp16, pdrop, train, pool_kernel_size=2,
                         nominal_pool_length=512, use_fused_kernel=True):
    conv_kernel = 4
    pool_kernel_size = pool_kernel_size or conv_kernel
    nx = shape_list(X)[-1]
    with tf.variable_scope(name):
        output = tf.nn.relu(normal_1d_conv_block(X, conv_kernel, "1-" + str(conv_kernel), use_fp16, output_dim=nx))
        output = tf.nn.relu(normal_1d_conv_block(output, conv_kernel, "2-" + str(conv_kernel), use_fp16, output_dim=nx))
        output = normal_1d_conv_block(output, conv_kernel, "3-" + str(conv_kernel), use_fp16, output_dim=nx)
        output = dropout(output, pdrop, train)
        aggregated = cascaded_pool(
            output,
            kernel_size=pool_kernel_size,
            pool_len=nominal_pool_length,
            use_fused_kernel=use_fused_kernel,
        )
        return normal_1d_conv_block(aggregated, 1, "output_reproject", use_fp16, output_dim=nx)

def cascaded_pool(value, kernel_size, dim=1, pool_len=None, use_fused_kernel=True):
    shape = shape_list(value)
    full_pool_len = pool_len or shape[dim]
    if use_fused_kernel:
        ra = recursive_agg
    else:
        ra = recursive_agg_tf
    aggregated = ra(value, kernel_size, full_pool_len)

    num_pooling_ops = shape_list(aggregated)[2]

    ws = normal_1d_conv_block(
        value, 1, "pool_w", value.dtype == tf.float16, output_dim=shape[-1]
    )
    wt = normal_1d_conv_block(
        value, 1, "pool_t", value.dtype == tf.float16, output_dim=num_pooling_ops
    )
    weights = tf.nn.softmax(wt)
    wt = tf.expand_dims(weights, -1)
    weighted_over_time = tf.reduce_mean(aggregated * wt, 2) * tf.nn.sigmoid(ws)
    return weighted_over_time

def norm(x, scope, axis=-1, e=None, fp16=False):
    with tf.variable_scope(scope):
        e = e or 1e-5
        n_state = shape_list(x)[-1]
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        if fp16:
            g = tf.cast(g, tf.float16)
            b = tf.cast(b, tf.float16)
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
        x = (x - u) * tf.rsqrt(s + e)
        x = x * g + b
        return x

def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch."""
    sequence_shape = shape_list(sequence_tensor)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]
    )
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    return output_tensor

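# Worked example: with batch_size=2, seq_length=4 and positions=[[1, 3], [0, 2]],
# flat_offsets is [[0], [4]], so flat_positions is [1, 3, 4, 6] -- i.e. rows
# (0,1), (0,3), (1,0) and (1,2) of the flattened [batch*seq, width] tensor.
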
def _unmerge_beam_dim(tensor, batch_size, beam_size):
    """Reshapes first dimension back to [batch_size, beam_size].

    Args:
        tensor: Tensor to reshape of shape [batch_size*beam_size, ...]
        batch_size: Tensor, original batch size.
        beam_size: int, original beam size.

    Returns:
        Reshaped tensor of shape [batch_size, beam_size, ...]
    """
    shape = shape_list(tensor)
    new_shape = [batch_size] + [beam_size] + shape[1:]
    return tf.reshape(tensor, new_shape)

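# Example: a [batch=2, beam=3, length=5] tensor becomes [6, 5] after
# _merge_beam_dim, and _unmerge_beam_dim(flat, 2, 3) restores [2, 3, 5].
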
def attn_weights(q, k, v, scale=False, mask=True, explain=False):
    w = tf.matmul(q, k)
    if scale:
        n_state = shape_list(v)[-1]
        w = w * tf.rsqrt(tf.cast(n_state, tf.float32))
    if mask:
        if explain:
            w = explain_mask_attn_weights(w)
        else:
            w = mask_attn_weights(w)
    w = tf.nn.softmax(w)
    return w

def enc_dec_mix(enc, dec, enc_mask, dec_mask, n_head=16):
    # enc = batch, seq, feats
    # dec = batch, seq, feats
    with tf.variable_scope("enc_dec_attn"):
        batch, dec_seq, feats = shape_list(dec)
        enc_seq = shape_list(enc)[1]
        enc_proj = normal_1d_conv_block(
            enc, 1, "enc_proj", use_fp16=enc.dtype == tf.float16, output_dim=feats * 2
        )
        dec_proj = normal_1d_conv_block(
            dec, 1, "dec_proj", use_fp16=enc.dtype == tf.float16, output_dim=feats
        )
        k, v = tf.split(enc_proj, 2, 2)
        q = dec_proj

        q = split_heads(q, n_head)
        k = split_heads(k, n_head, k=True)
        v = split_heads(v, n_head)

        w = tf.matmul(q, k)
        w = w * tf.rsqrt(cast_maybe(dec_seq, tf.float32))

        enc_mask = tf.reshape(
            tf.sequence_mask(enc_mask, maxlen=enc_seq, dtype=enc.dtype),
            [batch, 1, 1, enc_seq],
        )
        dec_mask = tf.reshape(
            tf.sequence_mask(dec_mask, maxlen=dec_seq, dtype=enc.dtype),
            [batch, 1, dec_seq, 1],
        )
        m = enc_mask * dec_mask
        w = w * m + -1e9 * (1 - m)

        w = tf.nn.softmax(w)
        return merge_heads(tf.matmul(w, v))

def explain_mask_attn_weights(w):
    # w is [batch, heads, n, n], where n = 2 * seq (real tokens followed by their explain/clf copies)
    batch, _, _, n = shape_list(w)
    seq = n // 2
    main_mask = tf.matrix_band_part(tf.ones([seq, seq]), -1, 0)
    top = tf.expand_dims(tf.concat((main_mask, tf.zeros([seq, seq])), 1), 0)  # 1, seq, 2 * seq
    clf_to_clf_mask = tf.eye(seq)
    bottom = tf.expand_dims(tf.concat((main_mask, clf_to_clf_mask), 1), 0)  # 1, seq, 2 * seq
    m = tf.concat((top, bottom), 1)
    b = tf.reshape(m, [1, 1, n, n])
    w = w * b + -1e9 * (1 - b)
    return w

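# For n = 4 (seq = 2) the mask m constructed above is:
#   [[1, 0, 0, 0],   # real token 0 attends only to itself
#    [1, 1, 0, 0],   # real token 1 attends to real tokens 0..1
#    [1, 0, 1, 0],   # explain/clf copy 0 attends to real token 0 and its own clf slot
#    [1, 1, 0, 1]]   # explain/clf copy 1 attends to real tokens 0..1 and its own clf slot
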