def _forward(dilated_ids, dilated_mask):
    '''Run one decoding pass and rebuild the dilated inputs for the next.

    Relies on names from the enclosing scope: `self`, `bert_config`,
    `batch_size`, `dilated_seq_length`, `max_seq_length`, `spad_id` and
    `tilda_embeddings`.

    Args:
      dilated_ids: ids passed to `self._bert_forward`.
      dilated_mask: mask passed to `self._bert_forward`.

    Returns:
      A `(dilated_ids, dilated_mask)` pair of int32 tensors, both reshaped
      to [batch_size, max_seq_length * 2].
    '''
    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    predicted = tf.cast(tf.argmax(logits, axis=-1), dtype=tf.int32)

    # For every predicted 0 in a row, append one `spad` token on the right.
    # Each row then holds exactly `dilated_seq_length` positive entries
    # (assumes spad_id > 0 -- TODO confirm), which the reshape below needs.
    zeros_per_row = tf.reduce_sum(
        tf.cast(tf.equal(predicted, 0), tf.int32), axis=-1)
    spad_tail = spad_id * tf.sequence_mask(
        zeros_per_row, dilated_seq_length, dtype=tf.int32)
    extended = tf.concat([predicted, spad_tail], axis=-1)

    # Squeeze out the non-positive entries and keep the leading
    # `max_seq_length` ids of every row.
    flat = tf.reshape(extended, [-1])
    keep = tf.cast(tf.greater(flat, 0), dtype=tf.int32)
    packed = tf.reshape(
        tf.boolean_mask(flat, keep), [batch_size, dilated_seq_length])
    head = packed[:, :max_seq_length]

    # Turn `spad` back into 0 and count the remaining tokens per row.
    keep_token = tf.cast(tf.not_equal(head, spad_id), dtype=tf.int32)
    token_ids = head * keep_token
    n_tokens = tf.reduce_sum(keep_token, axis=-1)

    # Dilate: pair every id / mask entry with a 0 and flatten the pairs, so
    # the result reads [x0, 0, x1, 0, ...].
    ids_col = tf.reshape(token_ids, [batch_size, max_seq_length, 1])
    mask_col = tf.reshape(
        tf.sequence_mask(n_tokens, max_seq_length, dtype=tf.int32),
        [batch_size, max_seq_length, 1])
    ids_pairs = tf.concat([ids_col, tf.zeros_like(ids_col)], axis=-1)
    mask_pairs = tf.concat(
        [mask_col, tf.zeros_like(mask_col, dtype=tf.int32)], axis=-1)
    next_ids = tf.reshape(ids_pairs, [batch_size, max_seq_length * 2])
    next_mask = tf.reshape(mask_pairs, [batch_size, max_seq_length * 2])
    return next_ids, next_mask
def crf_binary_score(tag_indices, sequence_lengths, transition_params):
    '''Sums the transition (binary) potentials along each tag sequence.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] matrix of binary potentials.

    Returns:
      binary_scores: A [batch_size] vector of binary scores.
    '''
    num_tags = transition_params.get_shape()[0]
    n_steps = tf.shape(tag_indices)[1] - 1

    # A transition goes from position t to t + 1, so pair the sequence
    # without its last element with the sequence without its first one.
    from_tags = tf.slice(tag_indices, [0, 0], [-1, n_steps])
    to_tags = tf.slice(tag_indices, [0, 1], [-1, n_steps])

    # Row-major index of (from, to) in the flattened transition matrix.
    pair_index = from_tags * num_tags + to_tags
    pair_scores = tf.gather(tf.reshape(transition_params, [-1]), pair_index)

    # Dropping the first mask column aligns the mask with the transitions:
    # a transition counts only if its destination is inside the sequence.
    step_mask = tf.slice(
        tf.sequence_mask(sequence_lengths,
                         maxlen=tf.shape(tag_indices)[1],
                         dtype=tf.float32),
        [0, 1], [-1, -1])
    return tf.reduce_sum(pair_scores * step_mask, 1)
def crf_unary_score(tag_indices, sequence_lengths, inputs):
    '''Sums the unary potentials selected by each tag sequence.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
        potentials.

    Returns:
      unary_scores: A [batch_size] vector of unary scores.
    '''
    batch_size = tf.shape(inputs)[0]
    max_seq_len = tf.shape(inputs)[1]
    num_tags = tf.shape(inputs)[2]

    flat_potentials = tf.reshape(inputs, [-1])

    # Position of (batch b, step t, tag 0) inside the flattened potentials;
    # adding the tag index then addresses the chosen potential directly.
    batch_base = tf.expand_dims(
        tf.range(batch_size) * max_seq_len * num_tags, 1)
    step_base = tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
    base = batch_base + step_base
    # Match the integer width of `tag_indices` before adding.
    if tag_indices.dtype == tf.int64:
        base = tf.cast(base, tf.int64)

    flat_index = tf.reshape(base + tag_indices, [-1])
    picked = tf.reshape(
        tf.gather(flat_potentials, flat_index), [batch_size, max_seq_len])

    # Positions past the true length contribute nothing.
    valid = tf.sequence_mask(sequence_lengths,
                             maxlen=tf.shape(tag_indices)[1],
                             dtype=tf.float32)
    return tf.reduce_sum(picked * valid, 1)
def _forward(dilated_ids, dilated_mask):
    '''Run one decoding pass and rebuild the dilated inputs for the next.

    Relies on names from the enclosing scope: `self`, `bert_config`,
    `batch_size`, `dilated_seq_length`, `max_seq_length`, `spad_id` and
    `tilda_embeddings`.

    Args:
      dilated_ids: ids passed to `self._bert_forward`.
      dilated_mask: mask passed to `self._bert_forward`.

    Returns:
      A `(dilated_ids, dilated_mask)` pair of int32 tensors: the ids are
      reshaped to [batch_size, max_seq_length * 2]; the mask comes from
      `tf.sequence_mask(..., dilated_seq_length)`.
    '''
    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    predicted = tf.cast(tf.argmax(logits, axis=-1), dtype=tf.int32)

    # For every predicted 0 in a row, append one `spad` token on the right.
    # Each row then holds exactly `dilated_seq_length` positive entries
    # (assumes spad_id > 0 -- TODO confirm), which the reshape below needs.
    zeros_per_row = tf.reduce_sum(
        tf.cast(tf.equal(predicted, 0), tf.int32), axis=-1)
    spad_tail = spad_id * tf.sequence_mask(
        zeros_per_row, dilated_seq_length, dtype=tf.int32)
    extended = tf.concat([predicted, spad_tail], axis=-1)

    # Squeeze out the non-positive entries and keep the leading
    # `max_seq_length` ids of every row.
    flat = tf.reshape(extended, [-1])
    keep = tf.cast(tf.greater(flat, 0), dtype=tf.int32)
    packed = tf.reshape(
        tf.boolean_mask(flat, keep), [batch_size, dilated_seq_length])
    head = packed[:, :max_seq_length]

    # Turn `spad` back into 0.
    keep_token = tf.cast(tf.not_equal(head, spad_id), dtype=tf.int32)
    token_ids = head * keep_token

    # Dilate the ids: pair every id with a 0 and flatten the pairs, so the
    # result reads [x0, 0, x1, 0, ...].
    ids_col = tf.reshape(token_ids, [batch_size, max_seq_length, 1])
    ids_pairs = tf.concat([ids_col, tf.zeros_like(ids_col)], axis=-1)
    next_ids = tf.reshape(ids_pairs, [batch_size, max_seq_length * 2])

    # NOTE(review): this mask marks the first `n_tokens` positions of the
    # dilated sequence, whereas the ids above sit on every other position --
    # confirm this is the mask layout `_bert_forward` expects.
    n_tokens = tf.reduce_sum(keep_token, axis=-1)
    next_mask = tf.sequence_mask(
        n_tokens, dilated_seq_length, dtype=tf.int32)
    return next_ids, next_mask
def create_attention_mask_from_input_mask(self,
                                          input_mask,
                                          batch_size,
                                          max_seq_length,
                                          dtype=tf.float32):
    '''Builds the attention mask for the mode stored in `self._mode`.

    Modes:
      'bi':  every query position may attend to every position `j` where
        `input_mask[b, j]` is set.
      'l2r': position `i` may attend to positions `j <= i`; `input_mask`
        is not consulted.
      'r2l': position `i` may attend to positions `j >= i` where
        `input_mask[b, j]` is set.
      's2s': the entries of `input_mask` are passed to `tf.sequence_mask`
        as lengths, so each entry expands to a row of length
        `max_seq_length`.

    Args:
      input_mask: mask tensor; reshaped to [batch_size, 1, max_seq_length]
        in the 'bi' and 'r2l' modes.
      batch_size: batch dimension used for reshaping / tiling.
      max_seq_length: sequence dimension used for reshaping / tiling.
      dtype: dtype of the returned mask.

    Returns:
      A tensor of `dtype`; [batch_size, max_seq_length, max_seq_length]
      for 'bi', 'l2r' and 'r2l'.

    Raises:
      ValueError: if `self._mode` is not one of the supported modes
        (previously this surfaced as an `UnboundLocalError`).
    '''
    if self._mode == 'bi':
        to_mask = tf.cast(tf.reshape(
            input_mask, [batch_size, 1, max_seq_length]), dtype=dtype)
        broadcast_ones = tf.ones(
            shape=[batch_size, max_seq_length, 1], dtype=dtype)
        # [B, S, 1] * [B, 1, S] -> [B, S, S]
        return broadcast_ones * to_mask

    if self._mode == 'l2r':
        # Row i of sequence_mask(i + 1) has ones in columns 0..i.
        arange = tf.range(max_seq_length) + 1
        to_mask = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
        to_mask = tf.reshape(to_mask, [1, max_seq_length, max_seq_length])
        return tf.tile(to_mask, [batch_size, 1, 1])

    if self._mode == 'r2l':
        to_mask = tf.cast(tf.reshape(
            input_mask, [batch_size, 1, max_seq_length]), dtype=dtype)
        broadcast_ones = tf.ones(
            shape=[batch_size, max_seq_length, 1], dtype=dtype)
        cover_mask = broadcast_ones * to_mask
        # Row i of sequence_mask(i) has ones in columns 0..i-1; inverting
        # it leaves columns i..S-1, i.e. the current and later positions.
        arange = tf.range(max_seq_length)
        reverse = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
        reverse = tf.reshape(reverse, [1, max_seq_length, max_seq_length])
        reverse_mask = tf.tile(reverse, [batch_size, 1, 1])
        return (1 - reverse_mask) * cover_mask

    if self._mode == 's2s':
        return tf.cast(
            tf.sequence_mask(input_mask, max_seq_length), dtype)

    raise ValueError(
        'Unsupported attention mode: {!r}. Expected one of '
        '\'bi\', \'l2r\', \'r2l\' or \'s2s\'.'.format(self._mode))
def __init__(self,
             vocab_size,
             is_training,
             input_ids,
             input_mask,
             segment_ids,
             sample_weight=None,
             reduced_size=64,
             topic_size=1024,
             hidden_size=768,
             num_hidden_layers=12,
             num_attention_heads=12,
             bias=0,
             scope='vae',
             trainable=True,
             **kwargs):
    '''Builds the VAE graph: transformer encoder, latent space, decoder.

    The encoder output is projected to `reduced_size` per token, flattened
    and mapped to `miu` / `sigma` of width `topic_size`. A sample
    `miu + exp(sigma) * noise` is decoded back to per-token hidden states
    and scored against `input_ids` with the word embedding table as the
    output projection.

    Results are stored on the instance: `self.probs['miu']`,
    `self.probs['sigma']`, `self.preds['preds']`, `self.losses['losses']`
    and `self.total_loss` (presumably the dicts are created by the parent
    constructor -- confirm).

    NOTE(review): the nesting of the `tf.variable_scope` blocks determines
    variable names; confirm it against existing checkpoints.

    Args:
      vocab_size: vocabulary size passed to `Config`.
      is_training: when falsy, dropout is disabled and the latent noise is
        drawn uniformly from [-bias, bias) instead of a standard normal.
      input_ids: rank-2 tensor of token ids (rank checked via
        `util.get_shape_list`); also used as reconstruction labels.
      input_mask: mask whose sum over the last axis is used as the input
        length.
      segment_ids: token type ids for the embedding post-processor.
      sample_weight: optional weights, expanded on the last axis and
        multiplied into the per-token loss.
      reduced_size: per-token width of the bottleneck projection.
      topic_size: width of the latent vectors `miu` and `sigma`.
      hidden_size: transformer hidden size.
      num_hidden_layers: number of transformer layers.
      num_attention_heads: number of attention heads.
      bias: half-width of the uniform noise used when not training.
      scope: outer variable scope for encoder and decoder.
      trainable: forwarded to every variable-creating call.
      **kwargs: only `use_tilda_embedding` is read here.
    '''
    super().__init__()

    # Model hyper-parameters; dropout is switched off outside of training.
    config = Config(
        vocab_size,
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads)
    if not is_training:
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0

    input_shape = util.get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    use_tilda_embedding = kwargs.get('use_tilda_embedding')
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):
        with tf.variable_scope('embeddings'):
            (self.embedding_output, self.embedding_table) = \
                self.embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=config.vocab_size,
                    batch_size=batch_size,
                    max_seq_length=seq_length,
                    embedding_size=config.hidden_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name='word_embeddings',
                    tilda_embeddings=tilda_embeddings,
                    trainable=trainable)
            self.embedding_output = self.embedding_postprocessor(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=seq_length,
                hidden_size=config.hidden_size,
                use_token_type=True,
                segment_ids=segment_ids,
                token_type_vocab_size=config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=config.initializer_range,
                max_position_embeddings=config.max_position_embeddings,
                dropout_prob=config.hidden_dropout_prob,
                trainable=trainable)

        with tf.variable_scope('encoder'):

            # stacked transformer
            attention_mask = self.create_attention_mask_from_input_mask(
                input_mask, batch_size, seq_length)
            self.all_encoder_layers = self.transformer_model(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=seq_length,
                attention_mask=attention_mask,
                hidden_size=config.hidden_size,
                num_hidden_layers=config.num_hidden_layers,
                num_attention_heads=config.num_attention_heads,
                intermediate_size=config.intermediate_size,
                intermediate_act_fn=util.get_activation(config.hidden_act),
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=\
                    config.attention_probs_dropout_prob,
                initializer_range=config.initializer_range,
                trainable=trainable)

            # projection
            with tf.variable_scope('projection'):
                transformer_output = tf.layers.dense(
                    self.all_encoder_layers[-1],
                    reduced_size,
                    activation=util.gelu,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    trainable=trainable)
                # Flatten the per-token projections into one vector per
                # example.
                transformer_output = tf.reshape(
                    transformer_output, [batch_size, -1])
                input_length = tf.reduce_sum(input_mask, axis=-1)
                input_length = tf.cast(input_length, tf.float32)
                input_length_1d = tf.reshape(input_length, [batch_size])
                input_length_2d = tf.reshape(input_length, [batch_size, 1])

                # Zero the slots that belong to positions past the input
                # length and rescale the rest by seq_length / input_length.
                broadcast_mask = tf.sequence_mask(
                    tf.multiply(input_length_1d, reduced_size),
                    seq_length * reduced_size,
                    dtype=tf.float32)
                broadcast_mask = tf.multiply(
                    broadcast_mask, seq_length / input_length_2d)
                transformer_output *= broadcast_mask

                # latent space
                miu = tf.layers.dense(
                    transformer_output,
                    topic_size,
                    activation='tanh',
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    name='miu',
                    trainable=trainable)
                sigma = tf.layers.dense(
                    transformer_output,
                    topic_size,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    name='sigma',
                    trainable=trainable)
                self.probs['miu'] = miu
                self.probs['sigma'] = sigma

        with tf.variable_scope('decoder'):
            with tf.variable_scope('projection'):

                # reparameterization: normal noise for training, uniform
                # noise bounded by `bias` otherwise
                if is_training:
                    noise = tf.random_normal([batch_size, topic_size])
                else:
                    noise = tf.random_uniform(
                        [batch_size, topic_size],
                        minval=-bias, maxval=bias)
                decoder_input = miu + tf.exp(sigma) * noise

                # projection back to one `reduced_size` vector per token
                decoder_input = tf.layers.dense(
                    decoder_input,
                    seq_length * reduced_size,
                    activation=util.gelu,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    trainable=trainable)
                intermediate_input = tf.reshape(
                    decoder_input, [-1, seq_length, reduced_size])
                intermediate_input = util.layer_norm(
                    intermediate_input, trainable=trainable)
                intermediate_input = util.dropout(
                    intermediate_input, config.hidden_dropout_prob)

            # MLP
            with tf.variable_scope('intermediate'):
                intermediate_output = tf.layers.dense(
                    intermediate_input,
                    4 * reduced_size,
                    activation=util.gelu,
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
            with tf.variable_scope('output'):
                decoder_output = tf.layers.dense(
                    intermediate_output,
                    config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
                decoder_output = util.layer_norm(
                    decoder_output, trainable=trainable)
                decoder_output = util.dropout(
                    decoder_output, config.hidden_dropout_prob)
            # Only the final decoder states are exposed.
            self.all_decoder_layers = [decoder_output]

    # reconstruction
    with tf.variable_scope('cls/predictions'):
        with tf.variable_scope('transform'):
            input_tensor = tf.layers.dense(
                decoder_output,
                units=config.hidden_size,
                activation=util.get_activation(config.hidden_act),
                kernel_initializer=util.create_initializer(
                    config.initializer_range),
                trainable=trainable)
            input_tensor = util.layer_norm(
                input_tensor, trainable=trainable)
        # The output projection reuses the word embedding table.
        output_weights = self.embedding_table
        output_bias = tf.get_variable(
            'output_bias',
            shape=[config.vocab_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)
        flatten_input_tensor = tf.reshape(
            input_tensor, [-1, config.hidden_size])

        logits = tf.matmul(
            flatten_input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        logits = tf.reshape(
            logits, [batch_size, seq_length, config.vocab_size])
        probs = tf.nn.softmax(logits, axis=-1, name='probs')
        lm_log_probs = tf.nn.log_softmax(logits, axis=-1)

        self.preds['preds'] = tf.argmax(probs, axis=-1)
        one_hot_labels = tf.one_hot(
            input_ids, depth=config.vocab_size, dtype=tf.float32)
        # Per-token negative log-likelihood, shape [batch, seq].
        # NOTE(review): padded positions are not masked out of this loss --
        # confirm that is intended.
        per_example_loss = -tf.reduce_sum(
            lm_log_probs * one_hot_labels, axis=[-1])
        if sample_weight is not None:
            per_example_loss *= tf.expand_dims(sample_weight, axis=-1)

        # Reconstruction term plus the two latent regularisers.
        self.total_loss = (
            tf.reduce_mean(per_example_loss) +
            tf.reduce_mean(tf.square(miu)) +
            tf.reduce_mean(tf.exp(sigma) - sigma - 1))
        self.losses['losses'] = per_example_loss
def __init__(self,
             is_training,
             input_tensor,
             n_wide_features,
             wide_features,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Builds a classifier that fuses wide feature embeddings with a deep
    representation through single-query attention.

    The deep vector attends over the embedded wide features, the attended
    summary is added to it and layer-normalised, and a linear layer yields
    the class logits. Results are stored in `self.preds['preds']`,
    `self.probs['probs']`, `self.losses['losses']` and `self.total_loss`.

    NOTE(review): the nesting of the `tf.variable_scope` blocks determines
    variable names; confirm it against existing checkpoints.

    Args:
      is_training: enables dropout before the output layer.
      input_tensor: deep representation; its static last dimension is the
        hidden size H (used below as [B, H]).
      n_wide_features: number of valid wide features per example, used as
        lengths for the attention mask.
      wide_features: ids of the wide features; its static last dimension N
        sizes both the mask and the embedding table (N + 1 rows).
      label_ids: integer class labels.
      label_size: number of classes.
      sample_weight: optional per-example weights.
      scope: variable scope of the output layer.
      hidden_dropout_prob: dropout rate applied when training.
      initializer_range: forwarded to `util.create_initializer`.
      trainable: forwarded to every variable-creating call.
      **kwargs: forwarded to the parent constructor; `tsa_thresh`
        (float) additionally enables the uncertainty filter below.
    '''
    super().__init__(**kwargs)

    hidden_size = input_tensor.shape.as_list()[-1]
    feature_size = wide_features.shape.as_list()[-1]
    with tf.variable_scope('wide'):
        feature_embeddings = tf.get_variable(
            name='feature_embeddings',
            shape=[feature_size + 1, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        wide_output = tf.gather(
            feature_embeddings, wide_features)  # [B, N, H]

    with tf.variable_scope('wide_and_deep'):
        deep_output = tf.expand_dims(input_tensor, -1)  # [B, H, 1]
        attention_scores = tf.matmul(
            wide_output, deep_output)  # [B, N, 1]
        attention_scores = tf.transpose(
            attention_scores, [0, 2, 1])  # [B, 1, N]
        attention_scores = tf.multiply(
            attention_scores, 1.0 / math.sqrt(hidden_size))
        feature_mask = tf.cast(
            tf.sequence_mask(n_wide_features, feature_size),
            tf.float32)  # [B, N]
        feature_mask = tf.expand_dims(feature_mask, 1)  # [B, 1, N]
        # Push masked-out features to a large negative score so softmax
        # gives them (almost) no weight.
        attention_scores += (1.0 - feature_mask) * -10000.0
        attention_matrix = tf.nn.softmax(attention_scores, axis=-1)
        attention_output = tf.matmul(
            attention_matrix, wide_output)  # [B, 1, H]
        attention_output = attention_output[:, 0, :]  # [B, H]
        # attention_output = util.dropout(
        #     attention_output, hidden_dropout_prob)
        input_tensor = util.layer_norm(
            attention_output + input_tensor,
            trainable=trainable)

    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        self.preds['preds'] = tf.argmax(logits, axis=-1)
        self.probs['probs'] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        if sample_weight is not None:
            per_example_loss = tf.cast(
                sample_weight, dtype=tf.float32) * per_example_loss
        thresh = kwargs.get('tsa_thresh')
        if thresh is not None:
            assert isinstance(
                thresh, float), (
                    '`tsa_thresh` must be a float between 0 and 1.')
            # Entropy of the prediction divided by log(1 / label_size),
            # giving a value in [0, 1]. `log_probs` comes from log_softmax,
            # so a probability that underflows to 0 contributes 0 rather
            # than 0 * log(0) = NaN.
            uncertainty = tf.reduce_sum(
                self.probs['probs'] * log_probs, axis=-1)
            uncertainty /= tf.log(1 / label_size)
            # Keep the loss only for examples the model is still unsure
            # about.
            per_example_loss = tf.cast(
                tf.greater(uncertainty, thresh), dtype=tf.float32) * \
                per_example_loss

        self.losses['losses'] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def __init__(self,
             bert_config,
             is_training,
             dilated_ids,
             label_ids,
             max_seq_length,
             spad_id=1,
             loop=3,
             sample_weight=None,
             scope='dilated',
             use_tilda_embedding=False,
             **kwargs):
    '''Builds the dilated language model graph.

    In training mode a single forward pass is scored against `label_ids`;
    `self.preds['LM']`, `self.losses['LM']` and `self.total_loss` are set.
    Otherwise the model is applied `loop` times, each pass re-dilating its
    own predictions, and the final dilated ids become `self.preds['LM']`.

    NOTE(review): the nesting of the `tf.variable_scope` blocks determines
    variable names; confirm it against existing checkpoints.

    Args:
      bert_config: configuration forwarded to `self._bert_forward`;
        `vocab_size` is read for the one-hot labels.
      is_training: selects the training branch or the inference loop.
      dilated_ids: rank-2 tensor of ids (rank checked via
        `util.get_shape_list`); non-zero entries define the initial mask.
      label_ids: target ids for the training loss.
      max_seq_length: undilated sequence length; the dilated length is
        treated as `max_seq_length * 2`.
      spad_id: id of the temporary `spad` padding token used while
        compacting predictions (assumed > 0 -- TODO confirm).
      loop: number of inference passes.
      sample_weight: optional per-example weights, assumed to be of shape
        [batch_size] -- TODO confirm.
      scope: variable scope wrapping the model.
      use_tilda_embedding: whether to fetch the shared `tilda_embeddings`
        variable.
      **kwargs: unused here.
    '''
    super().__init__()

    dilated_mask = tf.cast(
        tf.not_equal(dilated_ids, 0), tf.float32)

    shape = util.get_shape_list(dilated_ids, expected_rank=2)
    batch_size = shape[0]
    dilated_seq_length = shape[1]

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):

        # forward once
        if is_training:
            logits = self._bert_forward(
                bert_config,
                dilated_ids,
                dilated_mask,
                batch_size,
                dilated_seq_length,
                tilda_embeddings=tilda_embeddings)

            self.preds['LM'] = tf.argmax(logits, axis=-1)

            # LM loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(
                label_ids, depth=bert_config.vocab_size)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)

            # Each real token occupies two dilated slots, so the labelled
            # region is twice the number of non-zero input ids.
            input_length = tf.reduce_sum(dilated_mask, axis=-1) * 2
            label_mask = tf.sequence_mask(
                input_length, max_seq_length * 2, dtype=tf.float32)
            per_example_loss = \
                tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                tf.reduce_sum(label_mask, axis=-1)
            if sample_weight is not None:
                # `per_example_loss` is rank-1 ([batch_size]) here, so the
                # weights are applied element-wise. Expanding them to
                # [batch_size, 1] would broadcast the product to
                # [batch_size, batch_size].
                per_example_loss *= tf.cast(
                    sample_weight, dtype=tf.float32)

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['LM'] = per_example_loss

        # forward loop
        else:
            def _forward(dilated_ids, dilated_mask):
                '''Runs one pass and re-dilates the predicted ids.'''
                logits = self._bert_forward(
                    bert_config,
                    dilated_ids,
                    dilated_mask,
                    batch_size,
                    dilated_seq_length,
                    tilda_embeddings=tilda_embeddings)
                output_ids = tf.argmax(logits, axis=-1)
                output_ids = tf.cast(output_ids, dtype=tf.int32)

                # special padding (using `spad` token): one `spad` per
                # predicted 0, so every row keeps `dilated_seq_length`
                # positive entries for the reshape below
                equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
                equal_zero = tf.reduce_sum(equal_zero, axis=-1)
                right_pad = spad_id * tf.sequence_mask(
                    equal_zero, dilated_seq_length, dtype=tf.int32)
                paded = tf.concat([output_ids, right_pad], axis=-1)

                # extract ids of length `max_seq_length`
                flattened_padded = tf.reshape(paded, [-1])
                is_valid = tf.cast(
                    tf.greater(flattened_padded, 0), dtype=tf.int32)
                flattened_valid = tf.boolean_mask(
                    flattened_padded, is_valid)
                valid = tf.reshape(
                    flattened_valid, [batch_size, dilated_seq_length])
                cutted_valid = valid[:, :max_seq_length]

                # replace `spad` token with `pad`
                non_spad_mask = tf.cast(tf.not_equal(
                    cutted_valid, spad_id), dtype=tf.int32)
                output_ids = cutted_valid * non_spad_mask
                output_length = tf.reduce_sum(non_spad_mask, axis=-1)

                # dilate: interleave every id / mask entry with a 0
                reshaped_ids = tf.reshape(
                    output_ids, [batch_size, max_seq_length, 1])
                reshaped_mask = tf.reshape(
                    tf.sequence_mask(output_length, max_seq_length,
                                     dtype=tf.int32),
                    [batch_size, max_seq_length, 1])
                concat_ids = tf.concat(
                    [reshaped_ids, tf.zeros_like(reshaped_ids)], axis=-1)
                concat_mask = tf.concat([
                    reshaped_mask,
                    tf.zeros_like(reshaped_mask, dtype=tf.int32)
                ], axis=-1)
                dilated_ids = tf.reshape(
                    concat_ids, [batch_size, max_seq_length * 2])
                dilated_mask = tf.reshape(
                    concat_mask, [batch_size, max_seq_length * 2])

                return dilated_ids, dilated_mask

            for _ in range(loop):
                dilated_ids, dilated_mask = _forward(
                    dilated_ids, dilated_mask)

            self.preds['LM'] = dilated_ids