def _get_discriminator_output(self, inputs, sample_weight, discriminator,
                              labels):
    '''Discriminator binary classifier.

    Projects the discriminator's sequence output through a dense hidden
    layer and a single-unit dense layer to get one logit per token. It then
    scores those logits against `labels` with sigmoid cross-entropy, masked
    by `inputs.input_mask`.

    Args:
      inputs: object exposing `input_mask`, which is cast to float32 and
        used as per-token loss weights.
      sample_weight: optional per-example weights, or None. When given, the
        weights scale the token losses of each example, so they affect both
        `per_example_loss` and the aggregated `loss`.
      discriminator: encoder exposing `get_sequence_output()`.
      labels: per-token 0/1 targets (presumably 1 = replaced token; confirm
        with the caller).

    Returns:
      A `DiscOutput` namedtuple with `loss`, `per_example_loss`, `probs`,
      `preds` and `labels`.
    '''
    with tf.variable_scope('discriminator_predictions'):
        hidden = tf.layers.dense(
            discriminator.get_sequence_output(),
            units=self.bert_config.hidden_size,
            activation=util.get_activation(self.bert_config.hidden_act),
            kernel_initializer=util.create_initializer(
                self.bert_config.initializer_range))
        logits = tf.squeeze(tf.layers.dense(hidden, units=1), -1)
        weights = tf.cast(inputs.input_mask, tf.float32)
        labelsf = tf.cast(labels, tf.float32)
        losses = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=labelsf) * weights
        if sample_weight is not None:
            # Weight the token losses themselves (broadcast over the sequence
            # axis). This way the aggregated `loss` below honours the sample
            # weights too; before, only `per_example_loss` was weighted.
            sample_weight = tf.cast(sample_weight, dtype=tf.float32)
            losses *= tf.expand_dims(sample_weight, axis=-1)
        # The 1e-6 keeps the division finite for an all-zero mask.
        per_example_loss = (tf.reduce_sum(losses, axis=-1) /
                            (1e-6 + tf.reduce_sum(weights, axis=-1)))
        loss = tf.reduce_sum(losses) / (1e-6 + tf.reduce_sum(weights))
        probs = tf.nn.sigmoid(logits)
        preds = tf.cast(tf.greater(probs, 0.5), tf.int32)
        DiscOutput = collections.namedtuple(
            'DiscOutput',
            ['loss', 'per_example_loss', 'probs', 'preds', 'labels'])
        return DiscOutput(loss=loss, per_example_loss=per_example_loss,
                          probs=probs, preds=preds, labels=labels)
def __init__(self,
             is_training,
             input_tensor,
             input_mask,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/sequence',
             name='',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Token-level softmax classifier with a masked cross-entropy loss.

    Projects every position of `input_tensor` onto `label_size` classes. It
    fills `self.preds[name]` (argmax ids), `self.probs[name]` (softmax),
    `self.losses[name]` (per-example loss) and `self.total_loss`.

    Args:
      is_training: enables dropout on `input_tensor` when True.
      input_tensor: tensor whose last two static dimensions are read as
        (seq_length, hidden_size).
      input_mask: 2-D [batch_size, seq_length] mask of real tokens.
      label_ids: per-token gold class ids.
      label_size: number of classes.
      sample_weight: optional per-example weights.
      scope: variable scope of the output projection.
      name: key under which predictions, probabilities and losses are stored.
      hidden_dropout_prob: dropout rate applied when training.
      initializer_range: stddev for the output weights initializer.
      trainable: whether the projection variables are trainable.
      **kwargs: forwarded to the base class.
    '''
    super().__init__(**kwargs)

    batch_size = tf.shape(input_tensor)[0]
    seq_length = input_tensor.shape.as_list()[-2]
    hidden_size = input_tensor.shape.as_list()[-1]
    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        output_layer = tf.reshape(output_layer, [-1, hidden_size])
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        logits = tf.reshape(logits, [-1, seq_length, label_size])

        self.preds[name] = tf.argmax(logits, axis=-1)
        self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        # Cross-entropy sums over the class axis, as the other classifier
        # heads do. A mean here would divide every token loss by
        # `label_size`.
        per_token_losses = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)

        # Shifted loss mask:
        #   - position 0 is always dropped;
        #   - position i in [1, seq_length - 2] takes input_mask[:, i + 1];
        #   - the last position is dropped.
        # A token therefore counts only when the *next* token is real, which
        # also drops the last real token of each sequence (presumably the
        # [CLS]/[SEP] markers; confirm with the tokenizer).
        loss_mask = tf.concat([
            tf.zeros((batch_size, 1), dtype=tf.float32),
            tf.cast(input_mask[:, 2:], dtype=tf.float32),
            tf.zeros((batch_size, 1), dtype=tf.float32),
        ], axis=-1)
        per_token_losses *= loss_mask

        # Average over the tokens that contribute to the loss, not over all
        # `seq_length` slots, so padding does not dilute the loss of short
        # sequences. The floor avoids 0/0 for an empty mask.
        num_tokens = tf.maximum(tf.reduce_sum(loss_mask, axis=-1), 1e-6)
        per_example_loss = (
            tf.reduce_sum(per_token_losses, axis=-1) / num_tokens)
        if sample_weight is not None:
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        self.losses[name] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def _forward(dilated_ids, dilated_mask):
    '''Run one decoding pass and re-dilate the prediction for the next pass.

    This is a closure. `self`, `bert_config`, `batch_size`,
    `dilated_seq_length`, `max_seq_length`, `spad_id` and `tilda_embeddings`
    come from the enclosing scope, which is not visible here.

    Requirements:
      - `spad_id` must be > 0, because 0 ids are filtered out with `> 0`.
      - `dilated_seq_length` must be >= `max_seq_length`.

    Returns:
      (dilated_ids, dilated_mask), each [batch_size, max_seq_length * 2].
      The predicted non-zero ids are compacted to the left, truncated to
      `max_seq_length`, and interleaved with 0 placeholders. The mask is 1
      on the even slots that hold a real id and 0 elsewhere.
    '''
    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    output_ids = tf.argmax(logits, axis=-1)
    output_ids = tf.cast(output_ids, dtype=tf.int32)

    # special padding (using `spad` token): append, per row, as many `spad`
    # ids as the row has 0 (pad) predictions. After every 0 is dropped, each
    # row then still holds exactly `dilated_seq_length` ids.
    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
    right_pad = spad_id * tf.sequence_mask(
        equal_zero, dilated_seq_length, dtype=tf.int32)
    padded = tf.concat([output_ids, right_pad], axis=-1)

    # extract ids of length `max_seq_length`
    flattened_padded = tf.reshape(padded, [-1])
    # `tf.boolean_mask` is documented to take a boolean mask; keep it bool
    # rather than relying on int32 coercion.
    is_valid = tf.greater(flattened_padded, 0)
    flattened_valid = tf.boolean_mask(flattened_padded, is_valid)
    valid = tf.reshape(flattened_valid, [batch_size, dilated_seq_length])
    cutted_valid = valid[:, :max_seq_length]

    # replace `spad` token with `pad`
    non_spad_mask = tf.cast(
        tf.not_equal(cutted_valid, spad_id), dtype=tf.int32)
    output_ids = cutted_valid * non_spad_mask
    output_length = tf.reduce_sum(non_spad_mask, axis=-1)

    # dilate: interleave each id (and its mask bit) with a 0 placeholder
    reshaped_ids = tf.reshape(output_ids, [batch_size, max_seq_length, 1])
    reshaped_mask = tf.reshape(
        tf.sequence_mask(output_length, max_seq_length, dtype=tf.int32),
        [batch_size, max_seq_length, 1])
    concat_ids = tf.concat(
        [reshaped_ids, tf.zeros_like(reshaped_ids)], axis=-1)
    concat_mask = tf.concat([
        reshaped_mask, tf.zeros_like(reshaped_mask, dtype=tf.int32)
    ], axis=-1)
    dilated_ids = tf.reshape(concat_ids, [batch_size, max_seq_length * 2])
    dilated_mask = tf.reshape(concat_mask, [batch_size, max_seq_length * 2])
    return dilated_ids, dilated_mask
def scatter_update(sequence, updates, positions):
    '''Scatter-update a sequence.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
        tensor.
      updates: A tensor of size batch_size*seq_len(*depth).
      positions: A [batch_size, n_positions] tensor.

    Returns:
      A tuple of two tensors:
        - A [batch_size, seq_len] or [batch_size, seq_len, depth] copy of
          `sequence` whose elements at `positions` are replaced by the
          values in `updates`. Updates to index 0 are ignored. When a
          position is duplicated, its updates are summed and then divided
          by the hit count.
        - A [batch_size, seq_len] int32 mask of which inputs were updated.
    '''
    shape = util.get_shape_list(sequence, expected_rank=[2, 3])
    has_depth = len(shape) == 3
    if has_depth:
        batch, length, depth = shape
    else:
        batch, length = shape
        depth = 1
        sequence = tf.expand_dims(sequence, -1)
    n_positions = util.get_shape_list(positions)[1]

    # Shift each row's positions so they index a flattened [batch * length]
    # axis.
    row_offsets = tf.expand_dims(length * tf.range(batch), -1)
    flat_positions = tf.reshape(positions + row_offsets, [-1, 1])

    scattered = tf.scatter_nd(
        flat_positions, tf.reshape(updates, [-1, depth]),
        [batch * length, depth])
    scattered = tf.reshape(scattered, [batch, length, depth])

    # Count how many updates land on each slot (scatter_nd sums duplicates).
    hit_counts = tf.scatter_nd(
        flat_positions, tf.ones([batch * n_positions], tf.int32),
        [batch * length])
    hit_counts = tf.reshape(hit_counts, [batch, length])
    # Never touch position 0 of any row.
    first_token_guard = tf.concat(
        [tf.zeros((batch, 1), tf.int32),
         tf.ones((batch, length - 1), tf.int32)], -1)
    hit_counts *= first_token_guard
    hit_counts_3d = tf.expand_dims(hit_counts, -1)

    # Divide summed duplicate updates by their count.
    if sequence.dtype == tf.float32:
        hit_counts_3d = tf.cast(hit_counts_3d, tf.float32)
        scattered /= tf.maximum(1.0, hit_counts_3d)
    else:
        assert sequence.dtype == tf.int32
        scattered = tf.cast(
            tf.divide(scattered, tf.maximum(1, hit_counts_3d)), tf.int32)

    updates_mask = tf.minimum(hit_counts, 1)
    updates_mask_3d = tf.minimum(hit_counts_3d, 1)
    updated_sequence = ((1 - updates_mask_3d) * sequence +
                        updates_mask_3d * scattered)
    if not has_depth:
        updated_sequence = tf.squeeze(updated_sequence, -1)
    return updated_sequence, updates_mask
def __init__(self,
             is_training,
             input_tensor,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Softmax classifier with optional entropy-based example filtering.

    Fills `self.preds['preds']`, `self.probs['probs']`,
    `self.losses['losses']` and `self.total_loss`.

    Args:
      is_training: enables dropout on `input_tensor` when True.
      input_tensor: 2-D [batch_size, hidden_size] features.
      label_ids: [batch_size] gold class ids.
      label_size: number of classes.
      sample_weight: optional per-example weights.
      scope: variable scope of the output projection.
      hidden_dropout_prob: dropout rate applied when training.
      initializer_range: stddev for the output weights initializer.
      trainable: whether the projection variables are trainable.
      **kwargs: forwarded to the base class. The optional key `tsa_thresh`
        (float in [0, 1]) keeps the loss only for examples whose normalised
        prediction entropy exceeds the threshold.

    Raises:
      ValueError: if `tsa_thresh` is given but is not a float in [0, 1].
    '''
    super().__init__(**kwargs)

    hidden_size = input_tensor.shape.as_list()[-1]
    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        self.preds['preds'] = tf.argmax(logits, axis=-1)
        self.probs['probs'] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        per_example_loss = - tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        if sample_weight is not None:
            per_example_loss = tf.cast(
                sample_weight, dtype=tf.float32) * per_example_loss

        thresh = kwargs.get('tsa_thresh')
        if thresh is not None:
            # Raise instead of assert: asserts vanish under `python -O`, and
            # the documented [0, 1] range was never actually checked.
            if not isinstance(thresh, float) or not 0.0 <= thresh <= 1.0:
                raise ValueError(
                    '`tsa_thresh` must be a float between 0 and 1.')
            # Normalised entropy H(p) / log(label_size), in [0, 1]. Use the
            # stable `log_probs` rather than tf.log(probs): an underflowed
            # probability would give 0 * -inf = NaN.
            uncertainty = tf.reduce_sum(
                self.probs['probs'] * log_probs, axis=-1)
            uncertainty /= tf.log(1 / label_size)
            per_example_loss = tf.cast(
                tf.greater(uncertainty, thresh), dtype=tf.float32) * \
                per_example_loss

        self.losses['losses'] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type,
                                 bi_data, bsz=None, dtype=None):
    '''Create relative positional encoding.

    Relative distances run from `klen` down towards `-qlen` ('bi') or `-1`
    ('uni'); the upper end of tf.range is exclusive. With `bi_data`, a
    mirrored backward sequence is also embedded and concatenated on axis 1.
    `dtype`, when given and not float32, is used for the frequency and
    position tensors. `clamp_len > 0` clips distances to
    [-clamp_len, clamp_len].

    Raises:
      ValueError: if `attn_type` is neither 'bi' nor 'uni'.
    '''
    needs_cast = dtype is not None and dtype != tf.float32

    def _maybe_cast(tensor):
        return tf.cast(tensor, dtype=dtype) if needs_cast else tensor

    def _maybe_clamp(pos_seq):
        if clamp_len > 0:
            return tf.clip_by_value(pos_seq, -clamp_len, clamp_len)
        return pos_seq

    freq_seq = _maybe_cast(tf.range(0, d_model, 2.0))
    inv_freq = 1 / (10000**(freq_seq / d_model))

    if attn_type == 'bi':
        beg, end = klen, -qlen
    elif attn_type == 'uni':
        beg, end = klen, -1
    else:
        raise ValueError('Unknown `attn_type` {}.'.format(attn_type))

    fwd_pos_seq = _maybe_clamp(_maybe_cast(tf.range(beg, end, -1.0)))
    if not bi_data:
        return positional_embedding(fwd_pos_seq, inv_freq, bsz)

    bwd_pos_seq = _maybe_clamp(_maybe_cast(tf.range(-beg, -end, 1.0)))
    if bsz is None:
        fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
        bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)
    else:
        # With bi_data, the batch size should be divisible by 2.
        assert bsz % 2 == 0
        half_bsz = bsz // 2
        fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, half_bsz)
        bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, half_bsz)
    return tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
def crf_unary_score(tag_indices, sequence_lengths, inputs):
    '''Computes the unary scores of tag sequences.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
        potentials.

    Returns:
      unary_scores: A [batch_size] vector of unary scores, in the dtype of
        `inputs`. Positions at or beyond each sequence length contribute 0.
    '''
    batch_size = tf.shape(inputs)[0]
    max_seq_len = tf.shape(inputs)[1]
    num_tags = tf.shape(inputs)[2]

    flattened_inputs = tf.reshape(inputs, [-1])

    # Flat index of inputs[b, t, tag_indices[b, t]] in `flattened_inputs`.
    offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1)
    offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
    # Use int32 or int64 based on tag_indices' dtype.
    if tag_indices.dtype == tf.int64:
        offsets = tf.cast(offsets, tf.int64)
    flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1])

    unary_scores = tf.reshape(
        tf.gather(flattened_inputs, flattened_tag_indices),
        [batch_size, max_seq_len])

    # Build the length mask in the scores' own dtype. A hard-coded float32
    # mask fails with a dtype mismatch for float16/float64 potentials.
    masks = tf.sequence_mask(sequence_lengths,
                             maxlen=tf.shape(tag_indices)[1],
                             dtype=unary_scores.dtype)

    unary_scores = tf.reduce_sum(unary_scores * masks, 1)
    return unary_scores
def create_attention_mask_from_input_mask(from_tensor, to_mask):
    '''Create 3D attention mask from a 2D tensor mask.

    Args:
      from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
        Only its shape is used; its values are ignored.
      to_mask: int32 Tensor of shape [batch_size, to_seq_length].

    Returns:
      float Tensor of shape [batch_size, from_seq_length, to_seq_length].
      Every query row repeats `to_mask`.
    '''
    from_shape = util.get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size, from_seq_length = from_shape[0], from_shape[1]
    to_seq_length = util.get_shape_list(to_mask, expected_rank=2)[1]

    # [batch_size, 1, to_seq_length]: which keys may be attended to.
    key_mask = tf.cast(
        tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

    # Padding only matters on the *key* side, so every query row (including
    # padded queries) gets a full row of ones before broadcasting.
    query_ones = tf.ones(shape=[batch_size, from_seq_length, 1],
                         dtype=tf.float32)
    return query_ones * key_mask
def __init__(self,
             xlnet_config,
             is_training,
             input_ids,
             seg_ids,
             input_mask,
             mems,
             perm_mask,
             target,
             target_mask,
             target_mapping,
             inp_q,
             sample_weight=None,
             **kwargs):
    '''XLNet pre-training head: encoder plus tied-weight LM loss.

    Builds an `XLNetEncoder` and scores its sequence output with `lm_loss`
    under the 'model' variable scope (AUTO_REUSE). The loss is masked by
    `target_mask`; `self.total_loss` is the masked sum divided by the sum of
    `target_mask`.

    NOTE(review): that division is 0/0 if `target_mask` sums to zero.

    Args:
      sample_weight: optional per-example weights. They are expanded on the
        last axis, which presumes `per_example_loss` is [batch, targets];
        confirm against `lm_loss`.
      **kwargs: forwarded to `XLNetEncoder`.
    '''
    super().__init__()

    dropout_rate = 0.1 if is_training else 0.0
    run_config = XLNetRunConfig(
        is_training=is_training,
        bi_data=True,
        use_tpu=False,
        use_bfloat16=False,
        dropout=dropout_rate,
        dropatt=dropout_rate,
        init='normal',
        init_range=0.1,
        init_std=0.02,
        clamp_len=-1)

    model = XLNetEncoder(
        xlnet_config=xlnet_config,
        is_training=is_training,
        input_ids=input_ids,
        seg_ids=seg_ids,
        input_mask=input_mask,
        mems=mems,
        perm_mask=perm_mask,
        target_mapping=target_mapping,
        inp_q=inp_q,
        **kwargs)

    with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
        per_example_loss, preds = lm_loss(
            hidden=model.get_sequence_output(),
            target=target,
            n_token=xlnet_config.n_token,
            d_model=xlnet_config.d_model,
            initializer=model.get_initializer(),
            lookup_table=model.get_embedding_table(),
            tie_weight=True,
            bi_data=run_config.bi_data,
            use_tpu=run_config.use_tpu)

    if sample_weight is not None:
        per_example_loss *= tf.expand_dims(
            tf.cast(sample_weight, dtype=tf.float32), axis=-1)

    masked_loss = per_example_loss * target_mask
    self.total_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(target_mask)
    self.losses['losses'] = masked_loss
    self.preds['preds'] = preds
    self.preds['mask'] = target_mask
def relu_kernel_transformation(data,
                               is_query,
                               projection_matrix=None,
                               numerical_stabilizer=0.001):
    '''Computes features for the ReLU-kernel.

    Computes random features for the ReLU kernel from
    https://arxiv.org/pdf/2009.14794.pdf.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - attention dimensions, H - heads, D - features.
      is_query: indicates whether input data is a query or key tensor
        (ignored: queries and keys are transformed identically).
      projection_matrix: random Gaussian matrix of shape [M, D], where M
        stands for the number of random features and each D x D sub-block
        has pairwise orthogonal rows. When None, ReLU is applied to `data`
        directly.
      numerical_stabilizer: small positive constant for numerical stability.

    Returns:
      Corresponding kernel feature map.
    '''
    del is_query
    features = data
    if projection_matrix is not None:
        # Project onto the M random directions, scaled by 1 / sqrt(M).
        num_features = tf.cast(projection_matrix.shape[0], tf.float32)
        features = tf.math.rsqrt(num_features) * tf.einsum(
            'blhd,md->blhm', data, projection_matrix)
    return tf.nn.relu(features) + numerical_stabilizer
def create_products_of_givens_rotations(dim, seed):
    r'''Constructs a 2D-tensor which is a product of Givens random rotations.

    Constructs a 2D-tensor of the form G_1 * ... * G_k, where G_i is a Givens
    random rotation. The resulting tensor mimics a matrix taken uniformly at
    random from the orthogonal group.

    Args:
      dim: number of rows/columns of the resulting 2D-tensor.
      seed: random seed. The same seed always yields the same matrix; the
        global NumPy RNG state is left untouched.

    Returns:
      The product of Givens random rotations, as a float32 [dim, dim] tensor.
    '''
    nb_givens_rotations = dim * int(math.ceil(math.log(float(dim))))
    q = np.eye(dim, dim)
    # Use a private generator instead of np.random.seed(), which would
    # silently reseed every other user of the global NumPy RNG.
    rng = np.random.RandomState(seed)
    for _ in range(nb_givens_rotations):
        random_angle = math.pi * rng.uniform()
        # The two indices must be distinct. Sampling with replacement can give
        # i == j, and the "rotation" then just rescales one row by
        # (cos - sin), which destroys orthogonality.
        index_i, index_j = sorted(rng.choice(dim, 2, replace=False))
        cos_a = math.cos(random_angle)
        sin_a = math.sin(random_angle)
        slice_i = q[index_i]
        slice_j = q[index_j]
        # Both new rows are computed before either is written back.
        new_slice_i = cos_a * slice_i + sin_a * slice_j
        new_slice_j = -sin_a * slice_i + cos_a * slice_j
        q[index_i] = new_slice_i
        q[index_j] = new_slice_j
    return tf.cast(tf.constant(q), dtype=tf.float32)
def mask_attn_weights(w):
    '''Apply the causal mask to attention logits.

    `w` has shape [batch, heads, dst_sequence, src_sequence], where
    information flows from src to dst. Allowed entries are kept as-is;
    masked entries are replaced by -1e10 (in `w`'s dtype), so softmax
    assigns them ~0 weight.
    '''
    _, _, nd, ns = shape_list(w)
    keep = tf.reshape(attention_mask(nd, ns, dtype=w.dtype), [1, 1, nd, ns])
    big_negative = tf.cast(1e10, w.dtype)
    return w * keep - big_negative * (1 - keep)
def multihead_attn(q, k, v):
    '''Causal scaled dot-product attention.

    q, k, v have shape [batch, heads, sequence, features]. Scores are
    divided by sqrt of `v`'s static last dimension, causally masked, and
    normalised with `softmax` before weighting `v`.
    '''
    scores = tf.matmul(q, k, transpose_b=True)
    scores = scores * tf.rsqrt(tf.cast(v.shape[-1].value, scores.dtype))
    weights = softmax(mask_attn_weights(scores))
    return tf.matmul(weights, v)
def _forward(dilated_ids, dilated_mask):
    '''Run one decoding pass and re-dilate the prediction for the next pass.

    This is a closure. `self`, `bert_config`, `batch_size`,
    `dilated_seq_length`, `max_seq_length`, `spad_id` and `tilda_embeddings`
    come from the enclosing scope, which is not visible here.

    Requirements:
      - `spad_id` must be > 0, because 0 ids are filtered out with `> 0`.
      - `dilated_seq_length` must be >= `max_seq_length`.

    Returns:
      (dilated_ids, dilated_mask), each [batch_size, max_seq_length * 2].
      The ids are the compacted predictions interleaved with 0 placeholders.
      The mask is 1 exactly on the even slots that hold a real id.

      Previously the mask was
      `tf.sequence_mask(n_valid, dilated_seq_length)`, i.e. the first n_valid
      *slots*. In the interleaved layout those cover only about half of the
      real ids plus some placeholders. NOTE(review): the interleaved mask
      matches the other `_forward` variant in this file; confirm it matches
      the training-time mask convention.
    '''
    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    output_ids = tf.argmax(logits, axis=-1)
    output_ids = tf.cast(output_ids, dtype=tf.int32)

    # Append, per row, one `spad` id per 0 prediction, so that after dropping
    # every 0 each row keeps exactly `dilated_seq_length` ids.
    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
    right_pad = spad_id * tf.sequence_mask(
        equal_zero, dilated_seq_length, dtype=tf.int32)
    padded = tf.concat([output_ids, right_pad], axis=-1)

    # Drop the 0 ids and keep the first `max_seq_length` of each row.
    flattened_padded = tf.reshape(padded, [-1])
    # `tf.boolean_mask` is documented to take a boolean mask.
    is_valid = tf.greater(flattened_padded, 0)
    flattened_valid = tf.boolean_mask(flattened_padded, is_valid)
    valid = tf.reshape(flattened_valid, [batch_size, dilated_seq_length])
    cutted_valid = valid[:, :max_seq_length]

    # Turn `spad` back into pad (0).
    nonpad_mask = tf.cast(tf.not_equal(cutted_valid, spad_id), dtype=tf.int32)
    output_ids = cutted_valid * nonpad_mask
    output_length = tf.reduce_sum(nonpad_mask, axis=-1)

    # Dilate: interleave each id, and its mask bit, with a 0 placeholder.
    reshaped_ids = tf.reshape(output_ids, [batch_size, max_seq_length, 1])
    reshaped_mask = tf.reshape(
        tf.sequence_mask(output_length, max_seq_length, dtype=tf.int32),
        [batch_size, max_seq_length, 1])
    concat_ids = tf.concat(
        [reshaped_ids, tf.zeros_like(reshaped_ids)], axis=-1)
    concat_mask = tf.concat(
        [reshaped_mask, tf.zeros_like(reshaped_mask, dtype=tf.int32)],
        axis=-1)
    dilated_ids = tf.reshape(concat_ids, [batch_size, max_seq_length * 2])
    dilated_mask = tf.reshape(concat_mask, [batch_size, max_seq_length * 2])
    return dilated_ids, dilated_mask
def create_attention_mask_from_input_mask(self, input_mask, batch_size,
                                          max_seq_length, dtype=tf.float32):
    '''Build a padding mask in which every position can attend to itself.

    Args:
      input_mask: [batch_size, max_seq_length] mask of attendable keys.
      batch_size: batch size used for the reshapes.
      max_seq_length: sequence length.
      dtype: dtype of the returned mask.

    Returns:
      [batch_size, max_seq_length, max_seq_length] tensor of 0/1 in `dtype`.
      Entry (i, j) is 1 when key j is unmasked or when i == j.
    '''
    to_mask = tf.cast(
        tf.reshape(input_mask, [batch_size, 1, max_seq_length]), dtype=dtype)
    broadcast_ones = tf.ones(shape=[batch_size, max_seq_length, 1],
                             dtype=dtype)
    mask = broadcast_ones * to_mask

    # Build the identity in `dtype` too. tf.eye defaults to float32, and the
    # addition below would fail for any other mask dtype.
    broadcast_eye = tf.tile(
        tf.reshape(tf.eye(max_seq_length, dtype=dtype),
                   [1, max_seq_length, max_seq_length]),
        [batch_size, 1, 1])
    mask += broadcast_eye
    # Collapse the possible value 2 (unmasked diagonal) back to 1.
    mask = tf.cast(tf.greater(mask, 0), dtype)
    return mask
def attention_mask(nd, ns, *, dtype):
    '''Return an [nd, ns] causal mask in `dtype`.

    1's fill the lower triangle, counted from the lower right corner: query
    row i may see key column j when j <= i + (ns - nd). This is the same as
    tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce
    garbage on TPUs.
    '''
    rows = tf.range(nd)[:, None]
    cols = tf.range(ns)
    visible = rows + (ns - nd) >= cols
    return tf.cast(visible, dtype)
def create_attention_mask_from_input_mask(to_mask, batch_size,
                                          max_seq_length, dtype=tf.float32):
    '''Broadcast a [batch_size, max_seq_length] key mask to 3D.

    Returns a [batch_size, max_seq_length, max_seq_length] tensor in `dtype`
    whose every query row equals `to_mask`.
    '''
    key_mask = tf.reshape(to_mask, [batch_size, 1, max_seq_length])
    key_mask = tf.cast(key_mask, dtype=dtype)
    query_ones = tf.ones(shape=[batch_size, max_seq_length, 1], dtype=dtype)
    return query_ones * key_mask
def softmax_kernel_transformation(data,
                                  is_query,
                                  projection_matrix=None,
                                  numerical_stabilizer=0.000001):
    '''Computes random features for the softmax kernel using FAVOR+ mechanism.

    Computes random features for the softmax kernel using FAVOR+ mechanism
    from https://arxiv.org/pdf/2009.14794.pdf:

        phi(x) = M^-1/2 * exp(w^T x' - ||x'||^2 / 2),  with x' = d^-1/4 * x.

    Both the projection and the norm term must use the same normalised x'.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - attention dimensions, H - heads, D - features.
      is_query: indicates whether input data is a query or key tensor.
        Queries are stabilised with a per-row max over features; keys with a
        single max over the whole tensor.
      projection_matrix: random Gaussian matrix of shape [M, D], where M
        stands for the number of random features and each D x D sub-block
        has pairwise orthogonal rows. Required.
      numerical_stabilizer: small positive constant for numerical stability.

    Returns:
      Corresponding kernel feature map.

    Raises:
      ValueError: if `projection_matrix` is None.
    '''
    if projection_matrix is None:
        raise ValueError(
            '`projection_matrix` is required for the softmax kernel.')

    # Scale by d^-1/4 so that q'.k' = q.k / sqrt(d). Apply the scaling to the
    # data itself so that both the projection and the norm term see it;
    # previously only the norm term was scaled.
    data_normalizer = 1.0 / tf.math.sqrt(
        tf.math.sqrt(tf.cast(data.shape[-1], tf.float32)))
    data = data_normalizer * data
    ratio = tf.math.rsqrt(tf.cast(projection_matrix.shape[0], tf.float32))

    data_dash = tf.einsum('blhd,md->blhm', data, projection_matrix)
    diag_data = tf.math.square(data)
    diag_data = tf.math.reduce_sum(
        diag_data, axis=tf.keras.backend.ndim(data) - 1)
    diag_data = diag_data / 2.0
    diag_data = tf.expand_dims(diag_data,
                               axis=tf.keras.backend.ndim(data) - 1)

    # Subtracting a max before exp() avoids overflow.
    if is_query:
        last_dims_t = (len(data_dash.shape) - 1, )
        stabilizer = tf.math.reduce_max(
            data_dash, axis=last_dims_t, keepdims=True)
    else:
        stabilizer = tf.math.reduce_max(data_dash)
    data_dash = ratio * (tf.math.exp(data_dash - diag_data - stabilizer) +
                         numerical_stabilizer)
    return data_dash
def __init__(self,
             is_training,
             input_tensor,
             input_mask,
             label_ids,
             label_size=5,
             sample_weight=None,
             scope='cls/sequence',
             name='',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Token classifier trained with a linear-chain CRF loss.

    Emission logits come from a linear projection of `input_tensor`. The
    loss is the negative CRF log-likelihood over sequences whose lengths are
    the row sums of `input_mask`.

    `self.preds[name]` is the per-token argmax of the emission logits, not a
    Viterbi decode. The logits and the learned transition matrix are exported
    in `self.probs`, presumably for decoding elsewhere; confirm with the
    caller.
    '''
    super().__init__(**kwargs)

    seq_length, hidden_size = input_tensor.shape.as_list()[-2:]
    dropout_rate = hidden_dropout_prob if is_training else 0.0

    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        flat_hidden = tf.reshape(
            util.dropout(input_tensor, dropout_rate), [-1, hidden_size])
        flat_logits = tf.nn.bias_add(
            tf.matmul(flat_hidden, output_weights, transpose_b=True),
            output_bias)
        logits = tf.reshape(flat_logits, [-1, seq_length, label_size])

        with tf.variable_scope('crf'):
            input_length = tf.reduce_sum(input_mask, axis=-1)
            log_likelihood, transition_matrix = \
                contrib.crf.crf_log_likelihood(
                    inputs=logits,
                    tag_indices=label_ids,
                    sequence_lengths=input_length)
            per_example_loss = -log_likelihood
            if sample_weight is not None:
                per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses[name] = per_example_loss
            self.preds[name] = tf.argmax(logits, axis=-1)
            self.probs['logits'] = logits
            self.probs['transition_matrix'] = transition_matrix
def create_attention_mask_from_input_mask(self, input_mask, batch_size,
                                          max_seq_length, dtype=tf.float32):
    '''Build a [batch_size, max_seq_length, max_seq_length] mask for `self._mode`.

    Modes:
      'bi':  entry (i, j) = input_mask[b, j] (bidirectional with padding).
      'l2r': entry (i, j) = 1 iff j <= i (causal); `input_mask` is unused.
      'r2l': entry (i, j) = input_mask[b, j] when j >= i, else 0.
      's2s': tf.sequence_mask(input_mask, max_seq_length), so row i of
             example b opens the first input_mask[b, i] keys. This implies
             `input_mask` holds per-position visible lengths in this mode;
             confirm with the data pipeline.

    Raises:
      ValueError: if `self._mode` is none of the above. Before, this fell
        through to an UnboundLocalError on `mask`.
    '''
    def _padding_mask():
        # Each query row repeats the key-side padding mask.
        to_mask = tf.cast(tf.reshape(
            input_mask, [batch_size, 1, max_seq_length]), dtype=dtype)
        broadcast_ones = tf.ones(
            shape=[batch_size, max_seq_length, 1], dtype=dtype)
        return broadcast_ones * to_mask

    if self._mode == 'bi':
        return _padding_mask()

    if self._mode == 'l2r':
        # Row i opens the first i + 1 keys.
        arange = tf.range(max_seq_length) + 1
        to_mask = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
        to_mask = tf.reshape(to_mask, [1, max_seq_length, max_seq_length])
        return tf.tile(to_mask, [batch_size, 1, 1])

    if self._mode == 'r2l':
        cover_mask = _padding_mask()
        # `reverse` row i opens keys j < i; its complement keeps j >= i.
        arange = tf.range(max_seq_length)
        reverse = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
        reverse = tf.reshape(reverse, [1, max_seq_length, max_seq_length])
        reverse_mask = tf.tile(reverse, [batch_size, 1, 1])
        return (1 - reverse_mask) * cover_mask

    if self._mode == 's2s':
        return tf.cast(tf.sequence_mask(input_mask, max_seq_length), dtype)

    raise ValueError(
        'Unknown attention mode {!r}; expected one of '
        "'bi', 'l2r', 'r2l', 's2s'.".format(self._mode))
def __init__(self,
             is_training,
             input_tensor,
             label_ids,
             label_size=2,
             sample_weight=None,
             label_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Multi-label classifier with one independent sigmoid per label.

    Args:
      is_training: enables dropout on `input_tensor` when True.
      input_tensor: 2-D [batch_size, hidden_size] features.
      label_ids: [batch_size, label_size] 0/1 targets; they must match the
        logits' shape for sigmoid cross-entropy.
      label_size: number of labels.
      sample_weight: optional per-example weights.
      label_weight: optional Python sequence of `label_size` per-label loss
        weights.
      scope: variable scope of the output projection.
      hidden_dropout_prob: dropout rate applied when training.
      initializer_range: stddev for the output weights initializer.
      trainable: whether the projection variables are trainable.
      **kwargs: forwarded to the base class.

    Fills `self.probs['probs']`, `self.preds['preds']` (bool, prob > 0.5),
    `self.losses['losses']` and `self.total_loss`.
    '''
    super().__init__(**kwargs)

    hidden_size = input_tensor.shape.as_list()[-1]
    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        probs = tf.nn.sigmoid(logits, name='probs')
        self.probs['probs'] = probs
        self.preds['preds'] = tf.greater(probs, 0.5)

        per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=tf.cast(label_ids, dtype=tf.float32))
        if label_weight is not None:
            label_weight = tf.constant(label_weight, dtype=tf.float32)
            label_weight = tf.reshape(label_weight, [1, label_size])
            per_example_loss *= label_weight
        per_example_loss = tf.reduce_mean(per_example_loss, axis=-1)
        if sample_weight is not None:
            # Cast as the other heads do; an int or float64 weight tensor
            # would otherwise fail the float32 multiply.
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        self.losses['losses'] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def dot_product_attention(q, k, v, bias, dropout_rate=0.0):
    """Dot-product attention.

    Args:
      q: Tensor with shape [..., length_q, depth_k]. Must be rank 4 or 5
        when `bias` is given.
      k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
        match with q.
      v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
        match with q.
      bias: optional 0/1 attention mask (1 = attend), or None. It is
        broadcast over queries with a ones-matmul (transpose_b), so its last
        axis is expected to be size 1. NOTE(review): likely
        [B, 1, length_kv, 1] for rank-4 q; confirm with the caller.
      dropout_rate: a float.

    Returns:
      Tensor with shape [..., length_q, depth_v].

    Raises:
      ValueError: if `bias` is given and q is neither rank 4 nor rank 5.
    """
    from_shape = util.get_shape_list(q)
    logits = tf.matmul(q, k, transpose_b=True)  # [..., length_q, length_kv]
    logits = tf.multiply(logits, 1.0 / math.sqrt(float(from_shape[-1])))

    if bias is not None:
        if len(from_shape) == 4:
            broadcast_ones = tf.ones(
                [from_shape[0], 1, from_shape[2], 1], tf.float32)
        elif len(from_shape) == 5:
            # from_shape = [B, N, Block_num, block_size, depth]
            broadcast_ones = tf.ones(
                [from_shape[0], 1, from_shape[2], from_shape[3], 1],
                tf.float32)
        else:
            raise ValueError(
                'dot_product_attention with a bias expects q of rank 4 or 5, '
                'got rank {}.'.format(len(from_shape)))
        bias = tf.matmul(broadcast_ones, tf.cast(bias, tf.float32),
                         transpose_b=True)

        # The mask is 1.0 for positions we want to attend and 0.0 for masked
        # positions. This makes an adder that is 0.0 for attended positions
        # and -10000.0 for masked ones; adding it before the softmax
        # effectively removes the masked positions.
        adder = (1.0 - bias) * -10000.0
        logits += adder

    attention_probs = tf.nn.softmax(logits, name="attention_probs")
    attention_probs = util.dropout(attention_probs, dropout_rate)
    return tf.matmul(attention_probs, v)
def __init__(self,
             is_training,
             input_tensor,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             name='',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Single-label softmax classifier over 2-D features.

    Stores argmax ids in `self.preds[name]`, softmax probabilities in
    `self.probs[name]` and the (optionally sample-weighted) cross-entropy
    in `self.losses[name]`. `self.total_loss` is the batch mean of that
    cross-entropy.
    '''
    super().__init__(**kwargs)

    hidden_size = input_tensor.shape.as_list()[-1]
    dropout_rate = hidden_dropout_prob if is_training else 0.0

    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        features = util.dropout(input_tensor, dropout_rate)
        logits = tf.nn.bias_add(
            tf.matmul(features, output_weights, transpose_b=True),
            output_bias)

        self.preds[name] = tf.argmax(logits, axis=-1)
        self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        gold = tf.one_hot(label_ids, depth=label_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(gold * log_probs, axis=-1)
        if sample_weight is not None:
            weights = tf.cast(sample_weight, dtype=tf.float32)
            per_example_loss = weights * per_example_loss

        self.losses[name] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def _get_fake_data(self, inputs, mlm_logits):
    '''Sample from the generator to create corrupted input.

    Steps:
      1. Draw one token per masked position from
         softmax(mlm_logits / temperature). When `config.disallow_correct`
         is set, the original token is excluded from sampling.
      2. Write the samples into `input_ids`.

    A position is labelled fake only if it was updated AND the sampled token
    differs from the original one.

    Returns:
      FakedData(inputs, is_fake_tokens, sampled_tokens).
    '''
    inputs = unmask(inputs)

    disallow = None
    if self.config.disallow_correct:
        disallow = tf.one_hot(
            inputs.masked_lm_ids, depth=self.bert_config.vocab_size,
            dtype=tf.float32)

    scaled_logits = mlm_logits / self.config.temperature
    # No gradient flows from the discriminator back through the sampling.
    sampled_tokens = tf.stop_gradient(
        sample_from_softmax(scaled_logits, disallow=disallow))
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)

    updated_input_ids, masked = scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
    unchanged = tf.cast(
        tf.equal(updated_input_ids, inputs.input_ids), tf.int32)
    labels = masked * (1 - unchanged)

    updated_inputs = get_updated_inputs(inputs, input_ids=updated_input_ids)
    FakedData = collections.namedtuple(
        'FakedData', ['inputs', 'is_fake_tokens', 'sampled_tokens'])
    return FakedData(inputs=updated_inputs,
                     is_fake_tokens=labels,
                     sampled_tokens=sampled_tokens)
def _cls_forward(self,
                 is_training,
                 input_tensor,
                 input_mask,
                 label_ids,
                 bert_config,
                 batch_size,
                 max_seq_length,
                 prob,
                 scope,
                 name,
                 sample_weight=None,
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02):
    '''Binary per-token classifier head with a mask-averaged loss.

    Adds `mean(per_example_loss)` to `self.total_loss` unless `prob == 0`.
    It always records `self.losses[name + '_loss']` and
    `self.preds[name + '_preds']`.

    `is_training`, `batch_size`, `max_seq_length`, `hidden_dropout_prob`
    and `initializer_range` are accepted for interface compatibility but
    unused. The kernel initializer uses `bert_config.initializer_range`.

    Args:
      input_tensor: per-token features (presumably
        [batch_size, seq_length, hidden]; confirm with the caller).
      input_mask: [batch_size, seq_length] mask of tokens that count.
      label_ids: per-token 0/1 labels.
      sample_weight: optional [batch_size] per-example weights.
    '''
    with tf.variable_scope(scope):
        logits = tf.layers.dense(
            input_tensor, 2,
            kernel_initializer=util.create_initializer(
                bert_config.initializer_range),
            trainable=True)

        # loss
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=2)
        per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)

        input_mask = tf.cast(input_mask, tf.float32)
        # Mean over the masked tokens of each example. The floor prevents
        # 0/0 = NaN for a fully masked example; for any non-empty 0/1 mask
        # the count is >= 1, so the floor has no effect.
        num_tokens = tf.maximum(
            tf.reduce_sum(input_mask, keepdims=True, axis=-1), 1e-6)
        per_token_loss *= input_mask / num_tokens
        per_example_loss = tf.reduce_sum(per_token_loss, axis=-1)
        if sample_weight is not None:
            # `per_example_loss` is already [batch_size]. Expanding the
            # weights to [batch_size, 1] would broadcast to
            # [batch_size, batch_size].
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        if prob != 0:
            self.total_loss += tf.reduce_mean(per_example_loss)
        self.losses[name + '_loss'] = per_example_loss
        self.preds[name + '_preds'] = tf.argmax(logits, axis=-1)
def __init__(self,
             hparams,
             is_training,
             input_ids,
             sample_weight=None,
             scope='model',
             given=1,
             use_tilda_embedding=False,
             **kwargs):
    '''GPT-2 language model with a next-token cross-entropy loss.

    Training mode runs the model once over `input_ids`. Otherwise the model
    keeps the first `given` ids and greedily appends argmax predictions,
    re-running the full forward pass for every new token (no `past` cache).

    NOTE(review): the decoding loop leaves `max_seq_length + 1` ids in
    `self.preds['LM']`; confirm that this is what callers expect.

    Args:
      hparams: GPT-2 hyper-parameters (n_vocab, n_embed, n_ctx, n_layer,
        n_predict).
      is_training: selects single-pass (True) or greedy decoding (False).
      input_ids: [batch_size, seq_length] int32 token ids.
      sample_weight: optional [batch_size] per-example weights.
      scope: variable scope of the model weights.
      given: number of prompt tokens kept when decoding.
      use_tilda_embedding: use the existing `tilda_embeddings` variable
        (SMART) as the token embedding table.
      **kwargs: accepted for interface compatibility; unused.

    Fills `self.preds['LM']`, `self.losses['LM']` and `self.total_loss`.
    '''
    super().__init__()

    batch_size = util.get_shape_list(input_ids, expected_rank=2)[0]
    max_seq_length = hparams.n_predict

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):

        def _forward(input_ids, past=None):
            '''Return (logits [batch, seq, n_vocab], stacked presents).'''
            batch, sequence = shape_list(input_ids)

            if tilda_embeddings is None:
                wte = tf.get_variable(
                    'word_embeddings', [hparams.n_vocab, hparams.n_embed],
                    initializer=tf.random_normal_initializer(stddev=0.02))
            else:
                wte = tilda_embeddings
            wpe = tf.get_variable(
                'wpe', [hparams.n_ctx, hparams.n_embed],
                initializer=tf.random_normal_initializer(stddev=0.01))
            past_length = 0 if past is None else tf.shape(past)[-2]
            h = (tf.gather(wte, input_ids) +
                 tf.gather(wpe, positions_for(input_ids, past_length)))

            # stacked transformer layers
            presents = []
            pasts = tf.unstack(past, axis=1) if past is not None else \
                [None] * hparams.n_layer
            assert len(pasts) == hparams.n_layer
            for layer, layer_past in enumerate(pasts):
                h, present = block(
                    h, 'h%d' % layer, past=layer_past, hparams=hparams)
                presents.append(present)
            present = tf.stack(presents, axis=1)
            h = norm(h, 'ln_f')

            # Language model loss. Do tokens <n predict token n?
            h_flat = tf.reshape(h, [batch * sequence, hparams.n_embed])
            logits = tf.matmul(h_flat, wte, transpose_b=True)
            logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
            return logits, present

        # Labels are the inputs shifted left by one, padded with 0.
        label_ids = tf.concat(
            [input_ids[:, 1:], tf.zeros([batch_size, 1], dtype=tf.int32)],
            axis=-1)

        if is_training:
            # forward once
            (logits, _) = _forward(input_ids)
            self.preds['LM'] = tf.argmax(logits, axis=-1)
        else:
            # forward loop: greedy decoding from the first `given` tokens
            input_ids = input_ids[:, 0:given]
            for cur_length in range(given, max_seq_length + 1):
                (logits, _) = _forward(input_ids)
                pred_ids = tf.argmax(
                    logits[:, cur_length - 1:cur_length, :], axis=-1)
                pred_ids = tf.cast(pred_ids, tf.int32)
                input_ids = tf.concat([input_ids, pred_ids], axis=-1)
            self.preds['LM'] = input_ids

        # loss over non-zero labels
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=hparams.n_vocab)
        per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
        # The floor avoids 0/0 = NaN for rows without any label. Any
        # non-empty row counts >= 1 label, so the floor has no effect there.
        num_labels = tf.maximum(tf.reduce_sum(label_mask, axis=-1), 1e-6)
        per_example_loss = \
            tf.reduce_sum(per_token_loss * label_mask, axis=-1) / num_labels
        if sample_weight is not None:
            # `per_example_loss` is [batch_size]. Expanding the weights to
            # [batch_size, 1] would broadcast to [batch_size, batch_size].
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses['LM'] = per_example_loss
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''Sample a permutation of the factorization order, and create an
    attention mask accordingly.

    A single permutation is sampled per call and shared (tiled) across every
    sequence in the batch.

    Args:
        inputs: int Tensor in shape [batch_size, seq_len], input ids.
        targets: int Tensor in shape [batch_size, seq_len], target ids.
        is_masked: bool Tensor in shape [batch_size, seq_len]. True means
            being selected for partial prediction.
        perm_size: the length of longest permutation. Could be set to be
            reuse_len. Should not be larger than reuse_len or there will be
            data leaks. `seq_len` must be divisible by it (reshape below).
        seq_len: int, sequence length.

    Returns:
        perm_mask: bool Tensor [batch_size, seq_len, seq_len];
            perm_mask[b, i, j] is True when position i cannot attend to j.
        new_targets: [batch_size, seq_len], `targets` shifted right by one
            position with the first input id in front.
        target_mask: float32 Tensor [batch_size, seq_len]; 1.0 on masked,
            non-functional tokens (the ones that carry a loss).
        inputs_k: `inputs`, returned unchanged.
        inputs_q: the same tensor as `target_mask`.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices: positions inside every consecutive
    # window of `perm_size` are shuffled, all windows with the same shuffle.
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked),
                                     non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to
    # the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # perm_mask[b, i, j]:
    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    # The key-side condition must vary along j (the last axis), hence the
    # expansion at axis 1. Expanding at the last axis would test position i
    # instead and let non-masked queries attend to masked tokens.
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
def attention_layer(self, from_tensor, to_tensor, attention_mask=None,
                    num_attention_heads=12, size_per_head=512,
                    query_act=None, key_act=None, value_act=None,
                    attention_probs_dropout_prob=0.0, initializer_range=0.02,
                    do_return_2d_tensor=False, batch_size=None,
                    from_max_seq_length=None, to_max_seq_length=None,
                    dtype=tf.float32, trainable=True):
    """Multi-headed scaled dot-product attention from `from_tensor` to
    `to_tensor`.

    Shape letters used below:
        B = batch size, F = `from_tensor` length, T = `to_tensor` length,
        N = `num_attention_heads`, H = `size_per_head`.

    Args:
        from_tensor: query source; flattened to 2D with
            `util.reshape_to_matrix`.
        to_tensor: key/value source; flattened the same way.
        attention_mask: optional [B, F, T] mask. Where it is 0, -10000 is
            added to the score before the softmax.
        num_attention_heads, size_per_head: N and H.
        query_act, key_act, value_act: activations of the three projections.
        attention_probs_dropout_prob: passed to `util.dropout` on the
            attention probabilities. It is applied unconditionally, so callers
            must pass 0.0 at inference.
        initializer_range: stddev for `util.create_initializer`.
        do_return_2d_tensor: return [B*F, N*H] instead of [B, F, N*H].
        batch_size, from_max_seq_length, to_max_seq_length: used directly in
            the reshapes; there is no fallback when they are None.
        dtype: dtype the mask is cast to.
        trainable: whether the projection weights are trainable.

    Returns:
        (context_layer, attention_scores). `attention_scores` is the scaled,
        masked pre-softmax tensor of shape [B, N, F, T].
    """

    def _split_heads(tensor, seq_length):
        # [B*S, N*H] -> [B, S, N, H] -> [B, N, S, H]
        tensor = tf.reshape(
            tensor,
            [batch_size, seq_length, num_attention_heads, size_per_head])
        return tf.transpose(tensor, [0, 2, 1, 3])

    def _project(source_2d, layer_name, activation):
        # [B*S, width] -> [B*S, N*H]
        return tf.layers.dense(
            source_2d,
            num_attention_heads * size_per_head,
            activation=activation,
            name=layer_name,
            kernel_initializer=util.create_initializer(initializer_range),
            trainable=trainable)

    from_2d = util.reshape_to_matrix(from_tensor)
    to_2d = util.reshape_to_matrix(to_tensor)

    queries = _project(from_2d, 'query', query_act)
    keys = _project(to_2d, 'key', key_act)
    values = _project(to_2d, 'value', value_act)

    # [B, N, F, H] x [B, N, T, H]^T -> [B, N, F, T]
    attention_scores = tf.matmul(
        _split_heads(queries, from_max_seq_length),
        _split_heads(keys, to_max_seq_length),
        transpose_b=True)
    attention_scores = tf.multiply(
        attention_scores, 1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # [B, F, T] -> [B, 1, F, T]; masked-out positions receive a large
        # negative bias so the softmax assigns them almost no weight.
        expanded_mask = tf.expand_dims(attention_mask, axis=[1])
        attention_scores += (1.0 - tf.cast(expanded_mask, dtype)) * -10000.0

    # [B, N, F, T]; dropping whole attended tokens follows the original
    # Transformer recipe.
    attention_probs = util.dropout(
        tf.nn.softmax(attention_scores, axis=-1),
        attention_probs_dropout_prob)

    # [B, N, F, T] x [B, N, T, H] -> [B, N, F, H] -> [B, F, N, H]
    context_layer = tf.transpose(
        tf.matmul(attention_probs, _split_heads(values, to_max_seq_length)),
        [0, 2, 1, 3])

    merged_width = num_attention_heads * size_per_head
    if do_return_2d_tensor:
        output_shape = [batch_size * from_max_seq_length, merged_width]
    else:
        output_shape = [batch_size, from_max_seq_length, merged_width]
    return (tf.reshape(context_layer, output_shape), attention_scores)
def __init__(self, bert_config, is_training, sketchy_encoder,
             intensive_encoder, query_mask, label_ids, has_answer,
             sample_weight=None, scope='retro_reader',
             matching_mechanism='cross-attention', beta_1=0.5, beta_2=0.5,
             threshold=1.0, trainable=True, **kwargs):
    """Build the Retro-Reader heads on top of two encoders.

    Three parts are built:
      * sketchy reading: a 2-class classifier on
        `sketchy_encoder.get_pooled_output()`, trained on `has_answer`;
      * intensive reading: start/end span prediction on
        `intensive_encoder.get_sequence_output()` after matching every token
        against the query tokens selected by `query_mask`, trained on
        `label_ids`;
      * rear verification: `beta_1 * score_diff + beta_2 * score_ext`,
        thresholded at `threshold` into `self.preds['verifier_preds']`.

    Args:
        bert_config: read for `initializer_range`, `hidden_dropout_prob` and
            `num_attention_heads`.
        is_training: dropout is applied only when True.
        sketchy_encoder: encoder exposing `get_pooled_output()`.
        intensive_encoder: encoder exposing `get_sequence_output()`.
        query_mask: [batch_size, seq_length]; multiplied into the sequence
            output, so 1 keeps a token as part of the query and 0 zeroes it.
        label_ids: [batch_size, 2]; column 0 is the start position and
            column 1 the end position.
        has_answer: [batch_size] class ids (0 or 1) for the sketchy head.
        sample_weight: optional [batch_size] per-example loss weights.
        scope: outer variable scope.
        matching_mechanism: 'cross-attention' or 'matching-attention'.
        beta_1, beta_2: weights of the intensive and sketchy scores.
        threshold: verifier decision threshold.
        trainable: whether created variables are trainable.
        **kwargs: forwarded to the base-class constructor.

    Raises:
        ValueError: if `matching_mechanism` is not one of the two options.
    """
    super().__init__(**kwargs)

    if matching_mechanism not in ('cross-attention', 'matching-attention'):
        raise ValueError(
            'Invalid `matching_mechanism`: %s. Expected "cross-attention" '
            'or "matching-attention".' % matching_mechanism)

    # Every dropout in this head, including the cross-attention one, is
    # disabled outside training.
    dropout_prob = bert_config.hidden_dropout_prob if is_training else 0.0

    # verifier
    with tf.variable_scope(scope):

        # sketchy reading module
        with tf.variable_scope('sketchy/prediction'):
            sketchy_output = sketchy_encoder.get_pooled_output()
            hidden_size = sketchy_output.shape.as_list()[-1]

            output_weights = tf.get_variable(
                'output_weights', shape=[2, hidden_size],
                initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable(
                'output_bias', shape=[2],
                initializer=tf.zeros_initializer(),
                trainable=trainable)

            output_layer = util.dropout(sketchy_output, dropout_prob)
            logits = tf.matmul(output_layer, output_weights,
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            one_hot_labels = tf.one_hot(has_answer, depth=2,
                                        dtype=tf.float32)
            per_example_loss = - tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)
            if sample_weight is not None:
                per_example_loss = tf.cast(
                    sample_weight, dtype=tf.float32) * per_example_loss

            self.losses['sketchy_losses'] = per_example_loss
            sketchy_loss = tf.reduce_mean(per_example_loss)

            # sketchy score: class-1 logit minus class-0 logit
            score_ext = logits[:, 1] - logits[:, 0]

        # intensive reading module
        with tf.variable_scope('intensive'):
            H = intensive_encoder.get_sequence_output()
            H_Q = H * tf.cast(tf.expand_dims(query_mask, axis=-1),
                              tf.float32)
            (batch_size, max_seq_length, hidden_size) = \
                util.get_shape_list(H)

            # cross-attention
            if matching_mechanism == 'cross-attention':
                with tf.variable_scope('cross_attention'):
                    attention_mask = \
                        self.create_attention_mask_from_input_mask(
                            query_mask, batch_size, max_seq_length)
                    (H_prime, _) = self.attention_layer(
                        from_tensor=H,
                        to_tensor=H_Q,
                        attention_mask=attention_mask,
                        num_attention_heads=bert_config.num_attention_heads,
                        size_per_head=(
                            hidden_size // bert_config.num_attention_heads),
                        attention_probs_dropout_prob=dropout_prob,
                        initializer_range=bert_config.initializer_range,
                        do_return_2d_tensor=False,
                        batch_size=batch_size,
                        from_max_seq_length=max_seq_length,
                        to_max_seq_length=max_seq_length,
                        trainable=trainable)

            # matching-attention
            elif matching_mechanism == 'matching-attention':
                with tf.variable_scope('matching_attention'):
                    output_weights = tf.get_variable(
                        'output_weights', shape=[hidden_size, hidden_size],
                        initializer=util.create_initializer(
                            bert_config.initializer_range),
                        trainable=trainable)
                    output_bias = tf.get_variable(
                        'output_bias', shape=[hidden_size],
                        initializer=tf.zeros_initializer(),
                        trainable=trainable)
                    trans = tf.matmul(
                        H_Q,
                        tf.tile(tf.expand_dims(output_weights, axis=0),
                                [batch_size, 1, 1]),
                        transpose_b=True)
                    trans = tf.nn.bias_add(trans, output_bias)
                    M = tf.nn.softmax(
                        tf.matmul(H, trans, transpose_b=True), axis=-1)
                    H_prime = tf.matmul(M, H_Q)

            with tf.variable_scope('prediction'):
                output_weights = tf.get_variable(
                    'output_weights', shape=[2, hidden_size],
                    initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=trainable)
                output_bias = tf.get_variable(
                    'output_bias', shape=[2],
                    initializer=tf.zeros_initializer(),
                    trainable=trainable)

                output_layer = util.dropout(H_prime, dropout_prob)
                output_layer = tf.reshape(
                    output_layer, [batch_size * max_seq_length, hidden_size])
                logits = tf.matmul(output_layer, output_weights,
                                   transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)
                logits = tf.reshape(logits, [batch_size, max_seq_length, 2])
                # [batch_size, 2, max_seq_length]: row 0 start, row 1 end
                logits = tf.transpose(logits, [0, 2, 1])
                probs = tf.nn.softmax(logits, axis=-1, name='probs')

                self.probs['mrc_probs'] = probs
                self.preds['mrc_preds'] = tf.argmax(logits, axis=-1)

                start_one_hot_labels = tf.one_hot(
                    label_ids[:, 0], depth=max_seq_length, dtype=tf.float32)
                end_one_hot_labels = tf.one_hot(
                    label_ids[:, 1], depth=max_seq_length, dtype=tf.float32)
                start_log_probs = tf.nn.log_softmax(logits[:, 0, :], axis=-1)
                end_log_probs = tf.nn.log_softmax(logits[:, 1, :], axis=-1)
                per_example_loss = (
                    - 0.5 * tf.reduce_sum(
                        start_one_hot_labels * start_log_probs, axis=-1)
                    - 0.5 * tf.reduce_sum(
                        end_one_hot_labels * end_log_probs, axis=-1))
                if sample_weight is not None:
                    # cast to match the float32 loss, as the sketchy head does
                    per_example_loss *= tf.cast(sample_weight,
                                                dtype=tf.float32)
                intensive_loss = tf.reduce_mean(per_example_loss)
                self.losses['intensive_losses'] = per_example_loss

                # score_has: max over positions k >= 1 of start_prob[k] +
                # end_prob[k] (inf-norm of non-negative values); position 0
                # serves as the null answer.
                score_has = tf.norm(
                    probs[:, 0, 1:] + probs[:, 1, 1:], np.inf, axis=-1)
                score_null = probs[:, 0, 0] + probs[:, 1, 0]
                score_diff = score_has - score_null

        # rear verification
        v = beta_1 * score_diff + beta_2 * score_ext
        self.preds['verifier_preds'] = \
            tf.cast(tf.greater(v, threshold), tf.int32)
        self.probs['verifier_probs'] = v

        self.total_loss = sketchy_loss + intensive_loss
def __init__(self, vocab_size, is_training, input_ids, input_mask,
             segment_ids, sample_weight=None, reduced_size=64,
             topic_size=1024, hidden_size=768, num_hidden_layers=12,
             num_attention_heads=12, bias=0, scope='vae', trainable=True,
             **kwargs):
    """Build a Transformer-based variational auto-encoder over token ids.

    Encoder: a BERT-style transformer whose last layer is projected to
    `reduced_size` features per token, flattened, masked by sequence length,
    and mapped to a `topic_size` latent mean `miu` (tanh) and `sigma`.
    Decoder: draws `miu + exp(sigma) * noise`, expands it back to
    [batch_size, seq_length, hidden_size], and reconstructs `input_ids`
    through an output layer tied to the input embedding table.

    Args:
        vocab_size: vocabulary size passed to `Config`.
        is_training: when False, dropout is disabled and the latent noise is
            drawn uniformly from [-bias, bias] instead of a standard normal.
        input_ids: int tensor of rank 2, [batch_size, seq_length].
        input_mask: [batch_size, seq_length]; its row sums are used as the
            sequence lengths.
        segment_ids: token-type ids for the embedding postprocessor.
        sample_weight: optional [batch_size] weights applied to every token
            loss of the example.
        reduced_size: per-token width after the encoder projection.
        topic_size: latent dimension.
        hidden_size, num_hidden_layers, num_attention_heads: passed to
            `Config`.
        bias: half-width of the uniform inference noise.
        scope: outer variable scope.
        trainable: whether created variables are trainable.
        **kwargs: `use_tilda_embedding` is read from here. If truthy, the
            existing `tilda_embeddings` variable (SMART algorithm) is used as
            the word-embedding table.

    Sets `self.probs['miu']`, `self.probs['sigma']`, `self.preds['preds']`,
    `self.losses['losses']` (per token, [batch_size, seq_length]) and
    `self.total_loss`.
    """
    super().__init__()

    # freeze parameters
    config = Config(vocab_size,
                    hidden_size=hidden_size,
                    num_hidden_layers=num_hidden_layers,
                    num_attention_heads=num_attention_heads)
    if not is_training:
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0

    input_shape = util.get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    use_tilda_embedding = kwargs.get('use_tilda_embedding')
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):
        with tf.variable_scope('embeddings'):
            (self.embedding_output, self.embedding_table) = \
                self.embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=config.vocab_size,
                    batch_size=batch_size,
                    max_seq_length=seq_length,
                    embedding_size=config.hidden_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name='word_embeddings',
                    tilda_embeddings=tilda_embeddings,
                    trainable=trainable)
            self.embedding_output = self.embedding_postprocessor(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=seq_length,
                hidden_size=config.hidden_size,
                use_token_type=True,
                segment_ids=segment_ids,
                token_type_vocab_size=config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=config.initializer_range,
                max_position_embeddings=config.max_position_embeddings,
                dropout_prob=config.hidden_dropout_prob,
                trainable=trainable)

        with tf.variable_scope('encoder'):

            # stacked transformer
            attention_mask = self.create_attention_mask_from_input_mask(
                input_mask, batch_size, seq_length)
            self.all_encoder_layers = self.transformer_model(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=seq_length,
                attention_mask=attention_mask,
                hidden_size=config.hidden_size,
                num_hidden_layers=config.num_hidden_layers,
                num_attention_heads=config.num_attention_heads,
                intermediate_size=config.intermediate_size,
                intermediate_act_fn=util.get_activation(config.hidden_act),
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=(
                    config.attention_probs_dropout_prob),
                initializer_range=config.initializer_range,
                trainable=trainable)

            # projection
            with tf.variable_scope('projection'):
                # [batch_size, seq_length, reduced_size] flattened to
                # [batch_size, seq_length * reduced_size]
                transformer_output = tf.layers.dense(
                    self.all_encoder_layers[-1], reduced_size,
                    activation=util.gelu,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    trainable=trainable)
                transformer_output = tf.reshape(transformer_output,
                                                [batch_size, -1])

                # Keep the features of the first `length` tokens (assumes
                # padding sits at the end of each row) and rescale them by
                # seq_length / length.
                input_length = tf.reduce_sum(input_mask, axis=-1)
                input_length = tf.cast(input_length, tf.float32)
                input_length_1d = tf.reshape(input_length, [batch_size])
                input_length_2d = tf.reshape(input_length, [batch_size, 1])
                broadcast_mask = tf.sequence_mask(
                    tf.multiply(input_length_1d, reduced_size),
                    seq_length * reduced_size,
                    dtype=tf.float32)
                # tf.maximum avoids 0 * inf = NaN for all-padding rows and
                # leaves every length >= 1 unchanged.
                broadcast_mask = tf.multiply(
                    broadcast_mask,
                    seq_length / tf.maximum(input_length_2d, 1.0))
                transformer_output *= broadcast_mask

                # latent space
                miu = tf.layers.dense(
                    transformer_output, topic_size,
                    activation='tanh',
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    name='miu', trainable=trainable)
                sigma = tf.layers.dense(
                    transformer_output, topic_size,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    name='sigma', trainable=trainable)
                self.probs['miu'] = miu
                self.probs['sigma'] = sigma

        with tf.variable_scope('decoder'):
            with tf.variable_scope('projection'):

                # reparameterization
                if is_training:
                    noise = tf.random_normal([batch_size, topic_size])
                else:
                    noise = tf.random_uniform([batch_size, topic_size],
                                              minval=-bias, maxval=bias)
                # NOTE(review): exp(sigma) is used as the standard deviation
                # here, while the regularizer in `total_loss`
                # (exp(sigma) - sigma - 1) matches sigma being a
                # log-variance. Confirm which parameterization is intended.
                decoder_input = miu + tf.exp(sigma) * noise

                # projection
                decoder_input = tf.layers.dense(
                    decoder_input, seq_length * reduced_size,
                    activation=util.gelu,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    trainable=trainable)
                intermediate_input = tf.reshape(
                    decoder_input, [-1, seq_length, reduced_size])
                intermediate_input = util.layer_norm(intermediate_input,
                                                     trainable=trainable)
                intermediate_input = util.dropout(
                    intermediate_input, config.hidden_dropout_prob)

            # MLP
            with tf.variable_scope('intermediate'):
                intermediate_output = tf.layers.dense(
                    intermediate_input, 4 * reduced_size,
                    activation=util.gelu,
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
            with tf.variable_scope('output'):
                decoder_output = tf.layers.dense(
                    intermediate_output, config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
                decoder_output = util.layer_norm(decoder_output,
                                                 trainable=trainable)
                decoder_output = util.dropout(decoder_output,
                                              config.hidden_dropout_prob)
            self.all_decoder_layers = [decoder_output]

        # reconstruction
        with tf.variable_scope('cls/predictions'):
            with tf.variable_scope('transform'):
                input_tensor = tf.layers.dense(
                    decoder_output,
                    units=config.hidden_size,
                    activation=util.get_activation(config.hidden_act),
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
                input_tensor = util.layer_norm(input_tensor,
                                               trainable=trainable)

            # output projection tied to the input word-embedding table
            output_weights = self.embedding_table
            output_bias = tf.get_variable(
                'output_bias', shape=[config.vocab_size],
                initializer=tf.zeros_initializer(),
                trainable=trainable)
            flatten_input_tensor = tf.reshape(input_tensor,
                                              [-1, config.hidden_size])

            logits = tf.matmul(flatten_input_tensor, output_weights,
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            logits = tf.reshape(logits,
                                [batch_size, seq_length, config.vocab_size])
            probs = tf.nn.softmax(logits, axis=-1, name='probs')
            lm_log_probs = tf.nn.log_softmax(logits, axis=-1)

            self.preds['preds'] = tf.argmax(probs, axis=-1)
            one_hot_labels = tf.one_hot(input_ids,
                                        depth=config.vocab_size,
                                        dtype=tf.float32)
            # per-token reconstruction loss, [batch_size, seq_length]
            per_example_loss = -tf.reduce_sum(lm_log_probs * one_hot_labels,
                                              axis=[-1])
            if sample_weight is not None:
                per_example_loss *= tf.expand_dims(
                    tf.cast(sample_weight, dtype=tf.float32), axis=-1)

            self.total_loss = (tf.reduce_mean(per_example_loss) +
                               tf.reduce_mean(tf.square(miu)) +
                               tf.reduce_mean(tf.exp(sigma) - sigma - 1))
            self.losses['losses'] = per_example_loss