def block(x, *, past, hparams, train=False):
    nx = x.shape[-1].value
    a = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams, train=train)
    x = x + a
    m = mlp(norm(x, 'ln_2'), 'mlp', nx * 4, hparams=hparams, train=train)
    x = x + m
    return x
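# The block above delegates to the norm(), attn(), and mlp() helpers, which are
# not shown in this excerpt. As a point of reference, a layer-norm helper in the
# same TF1 style could look like the sketch below. This is an illustrative
# sketch following GPT-2's conventions (the name layer_norm_sketch is ours), not
# necessarily the exact norm() used here; it assumes the same `tf` import as the
# rest of this file.
def layer_norm_sketch(x, scope, *, axis=-1, epsilon=1e-5):
    # Normalize the last axis of x to zero mean and unit variance, then apply a
    # learned scale g and shift b.
    with tf.variable_scope(scope):
        n_state = x.shape[-1].value
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
        x = (x - u) * tf.rsqrt(s + epsilon)
        return x * g + b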
def block(x, *, past, hparams, train=False): nx = x.shape[-1].value a = attn(norm(x, "ln_1"), "attn", nx, past=past, hparams=hparams, train=train) if hparams.adapter_size is not None: with tf.variable_scope("attn_adapter"): a = adapter(a, hparams.adapter_size, nx, train) x = x + a m = mlp(norm(x, "ln_2"), "mlp", nx * 4, hparams=hparams, train=train) if hparams.adapter_size is not None: with tf.variable_scope("dense_adapter"): m = adapter(m, hparams.adapter_size, nx, train) x = x + m return x
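# adapter() is called above but not defined in this excerpt. A minimal
# bottleneck-adapter sketch matching that call signature (Houlsby-style:
# down-project, nonlinearity, up-project, residual) might look like the
# following. This is an assumption for illustration, not the project's actual
# implementation; the name adapter_sketch is ours, and `train` is kept only
# because the call site passes it (a real implementation might use it for
# dropout).
def adapter_sketch(x, adapter_size, nx, train=False):
    # Down-project to the small adapter dimension, apply a nonlinearity, then
    # project back up to nx. The residual connection keeps the module close to
    # the identity at initialization.
    down = tf.layers.dense(x, adapter_size, activation=tf.nn.relu, name="down")
    up = tf.layers.dense(down, nx, name="up")
    return x + up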
def seq_lab_internal(hidden):
    attn_fn = functools.partial(
        attn,
        scope="seq_label_attn",
        n_state=nx,
        n_head=config.seq_num_heads,
        resid_pdrop=config.resid_p_drop,
        attn_pdrop=config.attn_p_drop,
        train=train,
        scale=False,
        mask=False,
    )
    n = norm(attn_fn(hidden) + hidden, "seq_label_residual")
    flat_logits = tf.layers.dense(n, n_targets)
    logits = tf.reshape(
        flat_logits, tf.concat([tf.shape(hidden)[:2], [n_targets]], 0)
    )
    return logits
def seq_lab_internal(hidden):
    attn_fn = functools.partial(
        attn,
        scope="seq_label_attn",
        n_state=nx,
        n_head=config.seq_num_heads,
        resid_pdrop=config.resid_p_drop,
        attn_pdrop=config.attn_p_drop,
        train=train,
        scale=False,
        mask=False,
        lengths=lengths,
    )
    n = norm(attn_fn(hidden) + hidden, "seq_label_residual")

    # Per-token label logits.
    flat_logits = tf.layers.dense(n, n_targets)
    logits = tf.reshape(
        flat_logits, tf.concat([tf.shape(hidden)[:2], [n_targets]], 0)
    )

    # Pairwise association logits between every pair of tokens.
    association_head = tf.layers.dense(n, nx)
    association_head = tf.reshape(
        association_head, tf.concat([tf.shape(hidden)[:2], [nx]], 0)
    )

    a = tf.expand_dims(association_head, 1)
    b = tf.expand_dims(association_head, 2)
    features = tf.concat(
        [
            a - b,
            a * b,
            tf.tile(a, [1, length, 1, 1]),
            tf.tile(b, [1, 1, length, 1]),
            # TODO: Think about using prediction as a feature for associations.
        ],
        axis=-1,
    )
    associations_flat = tf.layers.dense(
        tf.reshape(features, shape=[-1, nx * 4]), num_associations
    )
    associations = tf.reshape(
        associations_flat, [-1, length, length, num_associations]
    )

    return logits, associations_flat, associations
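# The association features above are built by broadcasting two expanded copies
# of association_head against each other. The small standalone shape check
# below illustrates the layout; batch=2, length=5, nx=8 are arbitrary values
# and _association_feature_shapes is a throwaway name used only for this
# illustration.
def _association_feature_shapes():
    batch, length, nx = 2, 5, 8
    head = tf.zeros([batch, length, nx])  # stand-in for association_head
    a = tf.expand_dims(head, 1)           # [batch, 1, length, nx]
    b = tf.expand_dims(head, 2)           # [batch, length, 1, nx]
    features = tf.concat(
        [
            a - b,                        # broadcasts to [batch, length, length, nx]
            a * b,
            tf.tile(a, [1, length, 1, 1]),
            tf.tile(b, [1, 1, length, 1]),
        ],
        axis=-1,
    )
    return features                       # [batch, length, length, 4 * nx]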
def masked_language_model(
    *,
    X,
    M,
    mlm_weights,
    mlm_positions,
    mlm_ids,
    embed_weights,
    hidden,
    config,
    reuse=None,
    train=False
):
    X = merge_leading_dims(X, 3)
    M = merge_leading_dims(M, 2)
    hidden = merge_leading_dims(hidden, 3)

    batch, seq, _ = shape_list(X)

    with tf.variable_scope("model/masked-language-model"):
        # Project the hidden states at the masked positions back to embedding size.
        gathered_hidden = gather_indexes(hidden, mlm_positions)
        final_proj = tf.layers.dense(
            gathered_hidden,
            units=config.n_embed,
            activation=act_fns[config.act_fn],
            kernel_initializer=tf.random_normal_initializer(stddev=config.weight_stddev),
            name="dense",
        )
        normed_proj = norm(final_proj, "LayerNorm")
        n_vocab = shape_list(embed_weights)[0]
        output_bias = tf.get_variable(
            "output_bias", shape=[n_vocab], initializer=tf.zeros_initializer()
        )
        # The output projection is tied to the input embedding matrix.
        logits = tf.matmul(normed_proj, embed_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        mlm_ids = tf.reshape(mlm_ids, [-1])
        mlm_weights = tf.reshape(mlm_weights, [-1])

        # Weighted cross-entropy over the masked positions only.
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(mlm_ids, depth=n_vocab, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
        numerator = tf.reduce_sum(mlm_weights * per_example_loss)
        denominator = tf.reduce_sum(mlm_weights) + 1e-5
        mlm_loss = numerator / denominator

    return {
        "logits": logits,
        "losses": mlm_loss,
    }
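# gather_indexes() above pulls out the hidden states at the masked positions.
# Its implementation is not part of this excerpt; the sketch below shows the
# standard BERT-style way to do it (an assumption about this codebase, hence
# the name gather_indexes_sketch). It reuses the shape_list() helper referenced
# elsewhere in this file.
def gather_indexes_sketch(sequence_tensor, positions):
    # Flatten [batch, seq, width] to [batch * seq, width] and gather the rows
    # addressed by (batch offset + position), giving [num_positions, width].
    batch_size, seq_length, width = shape_list(sequence_tensor)
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]
    )
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(
        sequence_tensor, [batch_size * seq_length, width]
    )
    return tf.gather(flat_sequence_tensor, flat_positions)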
def gpt2_featurizer(X, encoder, config, train=False, reuse=None, **kwargs):
    initial_shape = tf.shape(X)
    X = tf.reshape(X, shape=tf.concat(([-1], initial_shape[-2:]), 0))
    X.set_shape([None, None, None])

    with tf.variable_scope("model/featurizer", reuse=reuse):
        embed_weights = tf.get_variable(
            name="we",
            shape=[encoder.vocab_size + config.max_length, config.n_embed],
            initializer=tf.random_normal_initializer(stddev=config.weight_stddev),
        )
        if config.train_embeddings:
            embed_weights = dropout(embed_weights, config.embed_p_drop, train)
        else:
            embed_weights = tf.stop_gradient(embed_weights)

        X = tf.reshape(X, [-1, config.max_length, 2])
        h = embed(X, embed_weights)

        # Transformer
        pasts = [None] * config.n_layer
        for layer, past in enumerate(pasts):
            if (
                (config.n_layer - layer) == config.num_layers_trained
                and config.num_layers_trained != config.n_layer
                and config.adapter_size is None
            ):
                h = tf.stop_gradient(h)
                train_layer = False
            else:
                train_layer = train

            with tf.variable_scope("h%d" % layer):
                block_fn = functools.partial(
                    block, past=past, hparams=config, train=train
                )
                if config.low_memory_mode and train_layer:
                    block_fn = recompute_grad(block_fn, use_entire_scope=True)
                h = block_fn(h)

        h = norm(h, "ln_f")

        # Use hidden state at classifier token as input to final proj. + softmax
        clf_h = tf.reshape(h, [-1, config.n_embed])  # [batch * seq_len, embed]
        clf_token = encoder["_classify_"]
        pool_idx = tf.cast(
            tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1),
            tf.int32,
        )
        clf_h = tf.gather(
            clf_h,
            tf.range(shape_list(X)[0], dtype=tf.int32) * config.max_length + pool_idx,
        )
        clf_h = tf.reshape(
            clf_h, shape=tf.concat((initial_shape[:-2], [config.n_embed]), 0)
        )
        seq_feats = tf.reshape(
            h, shape=tf.concat((initial_shape[:-1], [config.n_embed]), 0)
        )

        lengths = lengths_from_eos_idx(eos_idx=pool_idx, max_length=shape_list(X)[0])

        return {
            "embed_weights": embed_weights,
            "features": clf_h,
            "sequence_features": seq_feats,
            "eos_idx": pool_idx,
            "lengths": lengths,
        }
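# shape_list() is used throughout these functions to read tensor shapes. Its
# body is not included in this excerpt; the sketch below is the standard
# GPT-2/BERT-style helper that prefers static dimensions and falls back to
# dynamic ones. The name shape_list_sketch marks it as an illustration rather
# than the definitive implementation used here.
def shape_list_sketch(x):
    # Return a list of dimensions, using static values where they are known at
    # graph-construction time and tf.shape() entries otherwise.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]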