def scatter_update(sequence, updates, positions):
    '''Scatter-update a sequence.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
      updates: A tensor of size batch_size*seq_len(*depth)
      positions: A [batch_size, n_positions] tensor

    Returns: A tuple of two tensors. First is a [batch_size, seq_len] or
      [batch_size, seq_len, depth] tensor of 'sequence' with elements at
      'positions' replaced by the values at 'updates'. Updates to index 0 are
      ignored. If there are duplicated positions the update is only applied
      once. Second is a [batch_size, seq_len] mask tensor of which inputs were
      updated.
    '''
    shape = util.get_shape_list(sequence, expected_rank=[2, 3])
    depth_dimension = (len(shape) == 3)
    if depth_dimension:
        B, L, D = shape
    else:
        B, L = shape
        D = 1
        sequence = tf.expand_dims(sequence, -1)
    N = util.get_shape_list(positions)[1]

    shift = tf.expand_dims(L * tf.range(B), -1)
    flat_positions = tf.reshape(positions + shift, [-1, 1])
    flat_updates = tf.reshape(updates, [-1, D])
    updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D])
    updates = tf.reshape(updates, [B, L, D])

    flat_updates_mask = tf.ones([B * N], tf.int32)
    updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L])
    updates_mask = tf.reshape(updates_mask, [B, L])
    not_first_token = tf.concat(
        [tf.zeros((B, 1), tf.int32), tf.ones((B, L - 1), tf.int32)], -1)
    updates_mask *= not_first_token
    updates_mask_3d = tf.expand_dims(updates_mask, -1)

    # account for duplicate positions
    if sequence.dtype == tf.float32:
        updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
        updates /= tf.maximum(1.0, updates_mask_3d)
    else:
        assert sequence.dtype == tf.int32
        updates = tf.divide(updates, tf.maximum(1, updates_mask_3d))
        updates = tf.cast(updates, tf.int32)
    updates_mask = tf.minimum(updates_mask, 1)
    updates_mask_3d = tf.minimum(updates_mask_3d, 1)

    updated_sequence = (((1 - updates_mask_3d) * sequence) +
                        (updates_mask_3d * updates))
    if not depth_dimension:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask

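# Illustrative sketch, not part of the library (assumes TF 1.x): the core trick
# in scatter_update() is flattening [batch, position] indices into a single
# [B * L] coordinate space before tf.scatter_nd. Names below are hypothetical.
def _demo_flat_scatter():
    positions = tf.constant([[1, 3], [2, 2]])                # B=2, L=5
    updates = tf.constant([[10, 30], [20, 21]])
    shift = tf.expand_dims(5 * tf.range(2), -1)              # [[0], [5]]
    flat_positions = tf.reshape(positions + shift, [-1, 1])  # [[1],[3],[7],[7]]
    # tf.scatter_nd sums duplicate indices (7 above gets 20 + 21 = 41), which
    # is why scatter_update() divides by the per-position hit count afterwards.
    # value: [0, 10, 0, 30, 0, 0, 0, 41, 0, 0]
    return tf.scatter_nd(flat_positions, tf.reshape(updates, [-1]), [10])
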
def dot_product_attention(q, k, v, bias, dropout_rate=0.0):
    """Dot-product attention.

    Args:
      q: Tensor with shape [..., length_q, depth_k].
      k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
        match with q.
      v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
        match with q.
      bias: bias Tensor (see attention_bias())
      dropout_rate: a float.

    Returns:
      Tensor with shape [..., length_q, depth_v].
    """
    logits = tf.matmul(q, k, transpose_b=True)  # [..., length_q, length_kv]
    logits = tf.multiply(
        logits, 1.0 / math.sqrt(float(util.get_shape_list(q)[-1])))

    if bias is not None:
        # `attention_mask` = [B, T]
        from_shape = util.get_shape_list(q)
        if len(from_shape) == 4:
            broadcast_ones = tf.ones(
                [from_shape[0], 1, from_shape[2], 1], tf.float32)
        elif len(from_shape) == 5:
            # from_shape = [B, N, Block_num, block_size, depth]
            broadcast_ones = tf.ones(
                [from_shape[0], 1, from_shape[2], from_shape[3], 1],
                tf.float32)

        bias = tf.matmul(broadcast_ones, tf.cast(bias, tf.float32),
                         transpose_b=True)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0
        # for masked positions, this operation will create a tensor which is
        # 0.0 for positions we want to attend and -10000.0 for masked
        # positions.
        adder = (1.0 - bias) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        logits += adder
    else:
        adder = 0.0

    attention_probs = tf.nn.softmax(logits, name="attention_probs")
    attention_probs = util.dropout(attention_probs, dropout_rate)
    return tf.matmul(attention_probs, v)

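# Minimal sketch, not library code (assumes TF 1.x): the additive-mask trick
# used above. A {0, 1} attention mask becomes {-10000, 0} and is added to the
# logits, so masked positions receive near-zero probability after the softmax.
def _demo_additive_mask():
    mask = tf.constant([[1., 1., 0.]])        # attend to the first two keys
    logits = tf.constant([[2., 1., 3.]])
    # value: approximately [[0.73, 0.27, 0.00]]
    return tf.nn.softmax(logits + (1.0 - mask) * -10000.0)
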
def create_attention_mask_from_input_mask(from_tensor, to_mask):
    '''Create 3D attention mask from a 2D tensor mask.

    Args:
      from_tensor: 2D or 3D Tensor of shape
        [batch_size, from_seq_length, ...].
      to_mask: int32 Tensor of shape [batch_size, to_seq_length].

    Returns:
      float Tensor of shape [batch_size, from_seq_length, to_seq_length].
    '''
    from_shape = util.get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]

    to_shape = util.get_shape_list(to_mask, expected_rank=2)
    to_seq_length = to_shape[1]

    to_mask = tf.cast(
        tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

    # We don't assume that `from_tensor` is a mask (although it could be). We
    # don't actually care if we attend *from* padding tokens (only *to*
    # padding tokens) so we create a tensor of all ones.
    #
    # `broadcast_ones` = [batch_size, from_seq_length, 1]
    broadcast_ones = tf.ones(
        shape=[batch_size, from_seq_length, 1], dtype=tf.float32)

    # Here we broadcast along two dimensions to create the mask.
    mask = broadcast_ones * to_mask

    return mask

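# Hedged usage sketch for the two-argument variant above (relies on the
# library's `util` helpers; values are illustrative): a [B, 1, T] key-side mask
# times a [B, F, 1] column of ones broadcasts to the [B, F, T] attention mask,
# so every query row sees the same key-side validity bits.
def _demo_broadcast_mask():
    to_mask = tf.constant([[1, 1, 0]])   # [batch_size=1, to_seq_length=3]
    from_tensor = tf.zeros([1, 2])       # from_seq_length=2
    # value: [[[1., 1., 0.],
    #          [1., 1., 0.]]]
    return create_attention_mask_from_input_mask(from_tensor, to_mask)
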
def create_attention_mask_from_input_mask(to_mask,
                                          batch_size,
                                          max_seq_length,
                                          dtype=tf.float32):
    to_mask = tf.cast(
        tf.reshape(to_mask, [batch_size, 1, max_seq_length]), dtype=dtype)
    broadcast_ones = tf.ones(
        shape=[batch_size, max_seq_length, 1], dtype=dtype)
    mask = broadcast_ones * to_mask
    return mask

def _forward(self, is_training, split_placeholders, **kwargs):

    if not is_training:
        return super()._forward(is_training, split_placeholders, **kwargs)

    aug_input_ids = tf.boolean_mask(
        split_placeholders['aug_input_ids'],
        mask=(1.0 - split_placeholders['is_supervised']),
        axis=0)
    aug_input_mask = tf.boolean_mask(
        split_placeholders['aug_input_mask'],
        mask=(1.0 - split_placeholders['is_supervised']),
        axis=0)
    aug_segment_ids = tf.boolean_mask(
        split_placeholders['aug_segment_ids'],
        mask=(1.0 - split_placeholders['is_supervised']),
        axis=0)
    input_ids = tf.concat(
        [split_placeholders['input_ids'], aug_input_ids], axis=0)
    input_mask = tf.concat(
        [split_placeholders['input_mask'], aug_input_mask], axis=0)
    segment_ids = tf.concat(
        [split_placeholders['segment_ids'], aug_segment_ids], axis=0)

    encoder = BERTEncoder(
        bert_config=self.bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        scope='bert',
        drop_pooler=self._drop_pooler,
        **kwargs)
    encoder_output = encoder.get_pooled_output()

    label_ids = split_placeholders['label_ids']
    is_expanded = tf.zeros_like(label_ids, dtype=tf.float32)
    batch_size = util.get_shape_list(aug_input_ids)[0]
    aug_is_expanded = tf.ones((batch_size), dtype=tf.float32)
    is_expanded = tf.concat([is_expanded, aug_is_expanded], axis=0)

    decoder = UDADecoder(
        is_training=is_training,
        input_tensor=encoder_output,
        is_supervised=split_placeholders['is_supervised'],
        is_expanded=is_expanded,
        label_ids=label_ids,
        label_size=self.label_size,
        sample_weight=split_placeholders.get('sample_weight'),
        scope='cls/seq_relationship',
        global_step=self._global_step,
        num_train_steps=self.total_steps,
        uda_softmax_temp=self._uda_softmax_temp,
        uda_confidence_thresh=self._uda_confidence_thresh,
        tsa_schedule=self._tsa_schedule,
        **kwargs)
    (total_loss, losses, probs, preds) = decoder.get_forward_outputs()
    return (total_loss, losses, probs, preds)

def noncausal_denominator(qs, ks):
    '''Computes FAVOR normalizer in noncausal attention.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].

    Returns:
      FAVOR normalizer in noncausal attention.
    '''
    all_ones = tf.ones([ks.shape[0]])
    ks_sum = tf.einsum('lbhm,l->bhm', ks, all_ones)
    return tf.einsum('lbhm,bhm->lbh', qs, ks_sum)

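# Sanity-check sketch (assumes TF 1.x; not part of the library): the two
# einsums above compute q'_i . (sum_j k'_j) for every query position i, which
# equals summing q'_i . k'_j over all keys j, i.e. the FAVOR softmax
# denominator.
def _demo_noncausal_denominator():
    qs = tf.random_normal([3, 1, 1, 2])             # [L, B, H, M]
    ks = tf.random_normal([3, 1, 1, 2])
    denom = noncausal_denominator(qs, ks)           # [L, B, H]
    brute = tf.einsum('lbhm,jbhm->lbh', qs, ks)     # explicit sum over keys j
    return denom, brute                             # equal up to float error
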
def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
    '''create causal attention mask.'''
    attn_mask = tf.ones([qlen, qlen], dtype=dtype)
    mask_u = tf.matrix_band_part(attn_mask, 0, -1)
    mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
    attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
    ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
    if same_length:
        mask_l = tf.matrix_band_part(attn_mask, -1, 0)
        ret = tf.concat(
            [ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
    return ret

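# Illustrative only (assumes TF 1.x): for qlen=3, mlen=2 the returned
# [qlen, mlen + qlen] mask is 1 where attention is *forbidden* (strictly future
# positions) and 0 elsewhere, so memory and past/current tokens stay visible.
def _demo_causal_mask():
    # value: [[0., 0., 0., 1., 1.],
    #         [0., 0., 0., 0., 1.],
    #         [0., 0., 0., 0., 0.]]
    return _create_mask(3, 2)
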
def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False):
    r'''Constructs the matrix of random projections.

    Constructs a matrix of random orthogonal projections. Each projection
    vector has direction chosen uniformly at random and either deterministic
    length \sqrt{d} or length taken from the \chi(d) distribution (in the
    latter case marginal distributions of the projections are d-dimensional
    Gaussian vectors with associated identity covariance matrix).

    Args:
      m: number of random projections.
      d: dimensionality of each random projection.
      seed: random seed used to construct projections.
      scaling: 1 if all the random projections need to be renormalized to have
        length \sqrt{d}, 0 if the lengths of random projections should follow
        \chi(d) distribution.
      struct_mode: if True then products of Givens rotations will be used to
        construct random orthogonal matrix. This bypasses Gram-Schmidt
        orthogonalization.

    Returns:
      The matrix of random projections of the shape [m, d].
    '''
    nb_full_blocks = int(m / d)
    block_list = []
    current_seed = seed
    for _ in range(nb_full_blocks):
        if struct_mode:
            q = create_products_of_givens_rotations(d, seed)
        else:
            unstructured_block = tf.random_normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q)
        current_seed += 1
    remaining_rows = m - nb_full_blocks * d
    if remaining_rows > 0:
        if struct_mode:
            q = create_products_of_givens_rotations(d, seed)
        else:
            unstructured_block = tf.random_normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q[0:remaining_rows])
    final_matrix = tf.concat(block_list, axis=0)
    current_seed += 1

    if scaling == 0:
        multiplier = tf.norm(
            tf.random_normal((m, d), seed=current_seed), axis=1)
    elif scaling == 1:
        multiplier = 1 / tf.math.rsqrt(float(d)) * tf.ones((m))
    else:
        raise ValueError('Scaling must be one of {0, 1}. Was %s' % scaling)

    return tf.matmul(tf.linalg.diag(multiplier), final_matrix)

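# Rough usage sketch (TF 1.x assumed; not part of the library): rows within
# each d x d block come from the Q factor of a QR decomposition, so they are
# mutually orthogonal; with scaling=1 every row is rescaled to length sqrt(d).
def _demo_projection_orthogonality():
    proj = create_projection_matrix(m=16, d=8, seed=1, scaling=1)  # [16, 8]
    block = proj[:8]                                               # first block
    # value: approximately 8 * identity(8)
    return tf.matmul(block, block, transpose_b=True)
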
def create_attention_mask_from_input_mask(self,
                                          input_mask,
                                          batch_size,
                                          max_seq_length,
                                          dtype=tf.float32):
    if self._mode == 'bi':
        to_mask = tf.cast(
            tf.reshape(input_mask, [batch_size, 1, max_seq_length]),
            dtype=dtype)
        broadcast_ones = tf.ones(
            shape=[batch_size, max_seq_length, 1], dtype=dtype)
        mask = broadcast_ones * to_mask

    elif self._mode == 'l2r':
        arange = tf.range(max_seq_length) + 1
        to_mask = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
        to_mask = tf.reshape(to_mask, [1, max_seq_length, max_seq_length])
        mask = tf.tile(to_mask, [batch_size, 1, 1])

    elif self._mode == 'r2l':
        to_mask = tf.cast(
            tf.reshape(input_mask, [batch_size, 1, max_seq_length]),
            dtype=dtype)
        broadcast_ones = tf.ones(
            shape=[batch_size, max_seq_length, 1], dtype=dtype)
        cover_mask = broadcast_ones * to_mask

        arange = tf.range(max_seq_length)
        reverse = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
        reverse = tf.reshape(reverse, [1, max_seq_length, max_seq_length])
        reverse_mask = tf.tile(reverse, [batch_size, 1, 1])

        mask = (1 - reverse_mask) * cover_mask

    elif self._mode == 's2s':
        mask = tf.cast(tf.sequence_mask(input_mask, max_seq_length), dtype)

    return mask

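# Sketch of the 'l2r' branch above (illustrative, TF 1.x assumed): with
# max_seq_length=4, tf.sequence_mask builds the causal pattern where position i
# may attend to positions 0..i only.
def _demo_l2r_mask():
    # value: [[ True, False, False, False],
    #         [ True,  True, False, False],
    #         [ True,  True,  True, False],
    #         [ True,  True,  True,  True]]
    return tf.sequence_mask(tf.range(4) + 1, 4)
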
def create_attention_mask_from_input_mask(self,
                                          input_mask,
                                          batch_size,
                                          max_seq_length,
                                          dtype=tf.float32):
    to_mask = tf.cast(
        tf.reshape(input_mask, [batch_size, 1, max_seq_length]), dtype=dtype)
    broadcast_ones = tf.ones(
        shape=[batch_size, max_seq_length, 1], dtype=dtype)
    mask = broadcast_ones * to_mask

    broadcast_eye = tf.tile(
        tf.reshape(tf.eye(max_seq_length),
                   [1, max_seq_length, max_seq_length]),
        [batch_size, 1, 1])
    mask += broadcast_eye
    mask = tf.cast(tf.greater(mask, 0), dtype)
    return mask

def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''
    Sample a permutation of the factorization order, and create an
    attention mask accordingly.

    Args:
      inputs: int64 Tensor in shape [batch_size, seq_len], input ids.
      targets: int64 Tensor in shape [batch_size, seq_len], target ids.
      is_masked: bool Tensor in shape [batch_size, seq_len]. True means being
        selected for partial prediction.
      perm_size: the length of longest permutation. Could be set to be
        reuse_len. Should not be larger than reuse_len or there will be data
        leaks.
      seq_len: int, sequence length.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
    non_mask_tokens = tf.logical_and(
        tf.logical_not(is_masked), non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to
    # the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked
    #    (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=-1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q

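# Illustrative walk-through (hypothetical values, assumes TF 1.x): with
# seq_len=6 and perm_size=2, the index pipeline above only permutes the
# factorization order *inside* each perm_size-sized block of consecutive
# positions; the relative order of the blocks themselves is preserved.
def _demo_perm_index():
    index = tf.range(6, dtype=tf.int64)
    index = tf.reshape(index, [-1, 2])       # blocks: [[0,1],[2,3],[4,5]]
    index = tf.transpose(tf.random_shuffle(tf.transpose(index)))
    # value: either [[0, 1, 2, 3, 4, 5]] or [[1, 0, 3, 2, 5, 4]]
    return tf.reshape(index, [1, -1])
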
def _expand_features(module, split_placeholders):

    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    non_reuse_len = module.max_seq_length - module.reuse_seq_length
    assert (module.perm_size <= module.reuse_seq_length and
            module.perm_size <= non_reuse_len)

    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :module.reuse_seq_length],
            target[:, :module.reuse_seq_length],
            is_masked[:, :module.reuse_seq_length],
            module.perm_size, module.reuse_seq_length)

    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, module.reuse_seq_length:],
            target[:, module.reuse_seq_length:],
            is_masked[:, module.reuse_seq_length:],
            module.perm_size, non_reuse_len)

    perm_mask_0 = tf.concat(
        [tf.cast(perm_mask_0, dtype=tf.float32),
         tf.ones([batch_size, module.reuse_seq_length, non_reuse_len])],
        axis=2)
    perm_mask_1 = tf.concat(
        [tf.zeros([batch_size, non_reuse_len, module.reuse_seq_length]),
         tf.cast(perm_mask_1, dtype=tf.float32)],
        axis=2)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        # TODO(geying): convert tensors from 1-D to 2-D

        indices = tf.range(module.max_seq_length, dtype=tf.int64)
        indices = tf.reshape(indices, [-1, module.max_seq_length])
        indices = tf.tile(indices, [batch_size, 1])
        bool_target_mask = tf.cast(target_mask, tf.bool)
        indices = tf.boolean_mask(indices, bool_target_mask)

        ##### extra padding due to CLS/SEP introduced after prepro
        actual_num_predict = tf.shape(indices)[1]
        pad_len = module._num_predict - actual_num_predict

        ##### target_mapping
        target_mapping = tf.one_hot(
            indices, module.max_seq_length, dtype=tf.float32)
        paddings = tf.zeros(
            [pad_len, module.max_seq_length], dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, paddings], axis=0)
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping, [-1, module._num_predict, module.max_seq_length])

        ##### target
        target = tf.boolean_mask(target, bool_target_mask)
        paddings = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, paddings], axis=0)
        split_placeholders['target'] = tf.reshape(
            target, [-1, module._num_predict])

        ##### target mask
        target_mask = tf.concat(
            [tf.ones([batch_size, actual_num_predict], dtype=tf.float32),
             tf.zeros([batch_size, pad_len], dtype=tf.float32)],
            axis=1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module._num_predict])
    else:
        split_placeholders['target'] = tf.reshape(
            target, [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(
        input_k, [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(
        input_q, [-1, module.max_seq_length])

    return split_placeholders

def __init__(self,
             bert_config,
             is_training,
             input_ids,
             input_mask=None,
             token_type_ids=None,
             use_one_hot_embeddings=True,
             scope=None,
             embedding_size=None,
             input_embeddings=None,
             input_reprs=None,
             update_embeddings=True,
             untied_embeddings=False):
    '''Constructor for BertModel.

    Args:
      bert_config: `BertConfig` instance.
      is_training: bool. true for training model, false for eval model.
        Controls whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      token_type_ids: (optional) int32 Tensor of shape
        [batch_size, seq_length].
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings. On the
        TPU, it is much faster if this is True, on the CPU or GPU, it is
        faster if this is False.
      scope: (optional) variable scope. Defaults to 'electra'.

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    '''
    bert_config = copy.deepcopy(bert_config)
    if not is_training:
        bert_config.hidden_dropout_prob = 0.0
        bert_config.attention_probs_dropout_prob = 0.0

    input_shape = util.get_shape_list(token_type_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
        input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    assert token_type_ids is not None

    if input_reprs is None:
        with tf.variable_scope(
                ((scope if untied_embeddings else 'electra') + '/embeddings'),
                reuse=tf.AUTO_REUSE):
            # Perform embedding lookup on the word ids
            if embedding_size is None:
                embedding_size = bert_config.hidden_size
            (token_embeddings, self.embedding_table) = embedding_lookup(
                input_ids=input_ids,
                vocab_size=bert_config.vocab_size,
                embedding_size=embedding_size,
                initializer_range=bert_config.initializer_range,
                word_embedding_name='word_embeddings',
                use_one_hot_embeddings=use_one_hot_embeddings)

        with tf.variable_scope(
                ((scope if untied_embeddings else 'electra') + '/embeddings'),
                reuse=tf.AUTO_REUSE):
            # Add positional embeddings and token type embeddings, then
            # layer normalize and perform dropout.
            self.embedding_output = embedding_postprocessor(
                input_tensor=token_embeddings,
                use_token_type=True,
                token_type_ids=token_type_ids,
                token_type_vocab_size=bert_config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=bert_config.initializer_range,
                max_position_embeddings=bert_config.max_position_embeddings,
                dropout_prob=bert_config.hidden_dropout_prob)
    else:
        self.embedding_output = input_reprs
    if not update_embeddings:
        self.embedding_output = tf.stop_gradient(self.embedding_output)

    with tf.variable_scope(scope, default_name='electra'):
        if self.embedding_output.shape[-1] != bert_config.hidden_size:
            self.embedding_output = tf.layers.dense(
                self.embedding_output, bert_config.hidden_size,
                name='embeddings_project')

        with tf.variable_scope('encoder'):
            # This converts a 2D mask of shape [batch_size, seq_length]
            # to a 3D mask of shape [batch_size, seq_length, seq_length]
            # which is used for the attention scores.
            attention_mask = create_attention_mask_from_input_mask(
                token_type_ids, input_mask)

            # Run the stacked transformer. Output shapes
            # attn_maps:
            #   [n_layers, batch_size, n_heads, seq_length, seq_length]
            (self.all_layer_outputs, self.attn_maps) = transformer_model(
                input_tensor=self.embedding_output,
                attention_mask=attention_mask,
                hidden_size=bert_config.hidden_size,
                num_hidden_layers=bert_config.num_hidden_layers,
                num_attention_heads=bert_config.num_attention_heads,
                intermediate_size=bert_config.intermediate_size,
                intermediate_act_fn=util.get_activation(
                    bert_config.hidden_act),
                hidden_dropout_prob=bert_config.hidden_dropout_prob,
                attention_probs_dropout_prob=(
                    bert_config.attention_probs_dropout_prob),
                initializer_range=bert_config.initializer_range,
                do_return_all_layers=True)
            self.sequence_output = self.all_layer_outputs[-1]
            self.pooled_output = self.sequence_output[:, 0]

def __init__(self,
             albert_config,
             is_training,
             input_ids,
             input_mask=None,
             segment_ids=None,
             scope='bert',
             drop_pooler=False,
             trainable=True,
             **kwargs):
    """Constructor for AlbertModel.

    Args:
      albert_config: `AlbertConfig` instance.
      is_training: bool. true for training model, false for eval model.
        Controls whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      segment_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      use_einsum: (optional) bool. Whether to use einsum or reshape+matmul for
        dense layers.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    albert_config = copy.deepcopy(albert_config)
    if not is_training:
        albert_config.hidden_dropout_prob = 0.0
        albert_config.attention_probs_dropout_prob = 0.0

    input_shape = util.get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
        input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    if segment_ids is None:
        segment_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    use_tilda_embedding = kwargs.get('use_tilda_embedding')
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):
        with tf.variable_scope("embeddings"):
            # Perform embedding lookup on the word ids.
            (self.word_embedding_output,
             self.output_embedding_table) = embedding_lookup(
                input_ids=input_ids,
                vocab_size=albert_config.vocab_size,
                embedding_size=albert_config.embedding_size,
                initializer_range=albert_config.initializer_range,
                word_embedding_name="word_embeddings",
                tilda_embeddings=tilda_embeddings,
                trainable=trainable)

            # Add positional embeddings and token type embeddings, then layer
            # normalize and perform dropout.
            self.embedding_output = embedding_postprocessor(
                input_tensor=self.word_embedding_output,
                use_token_type=True,
                segment_ids=segment_ids,
                token_type_vocab_size=albert_config.type_vocab_size,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=albert_config.initializer_range,
                max_position_embeddings=albert_config.max_position_embeddings,
                dropout_prob=albert_config.hidden_dropout_prob,
                trainable=trainable)

        with tf.variable_scope("encoder"):
            # Run the stacked transformer.
            # `sequence_output` shape = [batch_size, seq_length, hidden_size].
            self.all_encoder_layers = transformer_model(
                input_tensor=self.embedding_output,
                attention_mask=input_mask,
                hidden_size=albert_config.hidden_size,
                num_hidden_layers=albert_config.num_hidden_layers,
                num_hidden_groups=albert_config.num_hidden_groups,
                num_attention_heads=albert_config.num_attention_heads,
                intermediate_size=albert_config.intermediate_size,
                inner_group_num=albert_config.inner_group_num,
                intermediate_act_fn=util.get_activation(
                    albert_config.hidden_act),
                hidden_dropout_prob=albert_config.hidden_dropout_prob,
                attention_probs_dropout_prob=(
                    albert_config.attention_probs_dropout_prob),
                initializer_range=albert_config.initializer_range,
                do_return_all_layers=True,
                use_einsum=False,
                trainable=trainable)
            self.sequence_output = self.all_encoder_layers[-1]

        # The "pooler" converts the encoded sequence tensor of shape
        # [batch_size, seq_length, hidden_size] to a tensor of shape
        # [batch_size, hidden_size]. This is necessary for segment-level
        # (or segment-pair-level) classification tasks where we need a fixed
        # dimensional representation of the segment.
        with tf.variable_scope("pooler"):
            # We "pool" the model by simply taking the hidden state
            # corresponding to the first token. We assume that this has been
            # pre-trained.
            first_token_tensor = tf.squeeze(
                self.sequence_output[:, 0:1, :], axis=1)

            # trick: ignore the fully connected layer
            if drop_pooler:
                self.pooled_output = first_token_tensor
            else:
                self.pooled_output = tf.layers.dense(
                    first_token_tensor,
                    albert_config.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=util.create_initializer(
                        albert_config.initializer_range),
                    trainable=trainable)

def __init__(self,
             vocab_size,
             is_training,
             source_ids,
             target_ids,
             sos_id,
             sample_weight=None,
             hidden_size=768,
             num_blocks=6,
             num_attention_heads=12,
             scope='transformer',
             use_label_smoothing=False,
             use_tilda_embedding=False,
             trainable=True,
             **kwargs):
    super().__init__()

    dropout_rate = 0.0
    if is_training:
        dropout_rate = 0.1

    source_shape = util.get_shape_list(source_ids, expected_rank=2)
    target_shape = util.get_shape_list(target_ids, expected_rank=2)
    batch_size = source_shape[0]
    source_max_seq_length = source_shape[1]
    target_max_seq_length = target_shape[1]

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):
        source_mask = tf.math.equal(source_ids, 0)

        # embedding
        with tf.variable_scope('embeddings'):
            (enc, embedding_table) = embedding_lookup(
                input_ids=source_ids,
                vocab_size=vocab_size,
                batch_size=batch_size,
                max_seq_length=source_max_seq_length,
                embedding_size=hidden_size,
                word_embedding_name='word_embeddings',
                tilda_embeddings=tilda_embeddings)
            enc *= hidden_size ** 0.5  # scale
            enc += positional_encoding(enc, source_max_seq_length)
            enc = util.dropout(enc, dropout_rate)

        with tf.variable_scope('encoder'):
            # stacked multi-attention layers
            for i in range(num_blocks):
                with tf.variable_scope('block_%s' % i):
                    # self-attention
                    enc = multihead_attention(
                        queries=enc,
                        keys=enc,
                        values=enc,
                        key_masks=source_mask,
                        num_heads=num_attention_heads,
                        dropout_rate=dropout_rate,
                        training=is_training,
                        causality=False,
                        scope='self_attention')

                    # feed forward
                    enc = ff(enc, num_units=[hidden_size * 4, hidden_size])
            memory = enc

        def _forward(target_ids, target_mask, target_max_seq_length):

            with tf.variable_scope('decoder'):
                # shared embedding
                dec = tf.nn.embedding_lookup(embedding_table, target_ids)
                dec *= hidden_size ** 0.5  # scale
                dec += positional_encoding(dec, target_max_seq_length)
                dec = util.dropout(dec, dropout_rate)

                # blocks
                for i in range(num_blocks):
                    with tf.variable_scope('block_%s' % i):
                        # masked self-attention
                        dec = multihead_attention(
                            queries=dec,
                            keys=dec,
                            values=dec,
                            key_masks=target_mask,
                            num_heads=num_attention_heads,
                            dropout_rate=dropout_rate,
                            training=is_training,
                            causality=True,
                            scope='masked_self_attention')

                        # vanilla attention
                        dec = multihead_attention(
                            queries=dec,
                            keys=memory,
                            values=memory,
                            key_masks=source_mask,
                            num_heads=num_attention_heads,
                            dropout_rate=dropout_rate,
                            training=is_training,
                            causality=False,
                            scope='vanilla_attention')

                        # feed forward
                        dec = ff(
                            dec, num_units=[4 * hidden_size, hidden_size])

            # final linear projection (embedding weights are shared)
            with tf.variable_scope('cls'):
                output_bias = tf.get_variable(
                    'output_bias',
                    shape=[vocab_size],
                    initializer=tf.zeros_initializer())
                dec = tf.reshape(dec, [-1, hidden_size])
                logits = tf.matmul(dec, embedding_table, transpose_b=True)
                logits = tf.reshape(
                    logits, [-1, target_max_seq_length, vocab_size])
                logits = tf.nn.bias_add(logits, output_bias)

            return logits

        # convert to labels
        label_ids = tf.concat(
            [target_ids[:, 1:], tf.zeros([batch_size, 1], dtype=tf.int32)],
            axis=-1)

        # forward once
        if is_training:
            target_mask = tf.math.equal(target_ids, 0)  # (N, T2)
            logits = _forward(
                target_ids, target_mask, target_max_seq_length)

            self.preds['MT'] = tf.argmax(logits, axis=-1)

        # forward loop
        else:
            target_mask_base = tf.zeros([batch_size, 1], dtype=tf.int32)
            target_ids = tf.ones([batch_size, 1], dtype=tf.int32) * sos_id

            for cur_length in range(1, target_max_seq_length + 1):
                target_mask = tf.tile(target_mask_base, [1, cur_length])
                logits = _forward(target_ids, target_mask, cur_length)

                pred_ids = tf.argmax(
                    logits[:, cur_length - 1:cur_length, :], axis=-1)
                pred_ids = tf.cast(pred_ids, tf.int32)
                target_ids = tf.concat([target_ids, pred_ids], axis=-1)

            self.preds['MT'] = target_ids[:, 1:]

        # loss
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=vocab_size)
        if use_label_smoothing:
            one_hot_labels = label_smoothing(one_hot_labels)
        per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
        per_example_loss = \
            tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
            tf.reduce_sum(label_mask, axis=-1)
        if sample_weight is not None:
            per_example_loss *= tf.expand_dims(sample_weight, axis=-1)

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses['MT'] = per_example_loss

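# Hedged illustration of the label shift above (values are made up, TF 1.x
# assumed): each decoder position is trained to predict the *next* target
# token, with pad id 0 filling the final slot.
def _demo_label_shift():
    target_ids = tf.constant([[2, 7, 8, 9]])    # 2 standing in for sos_id
    # value: [[7, 8, 9, 0]]
    return tf.concat(
        [target_ids[:, 1:], tf.zeros([1, 1], dtype=tf.int32)], axis=-1)
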