def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
  to_seq_length = to_shape[1]

  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
      dtype=from_tensor.dtype)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding
  # tokens) so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)

  # Here we broadcast along two dimensions to create the mask.
  mask = broadcast_ones * to_mask

  return mask
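# Illustrative sketch (not part of the snippets above; assumes only plain
# TensorFlow, no `tf_utils`) of the broadcast that builds the
# [batch, from_seq, to_seq] mask from a [batch, to_seq] padding mask.
import tensorflow as tf

to_mask = tf.constant([[1, 1, 0], [1, 0, 0]], dtype=tf.float32)  # [2, 3]
from_seq_length = 4
# [batch, from_seq, 1] ones broadcast against [batch, 1, to_seq].
broadcast_ones = tf.ones([2, from_seq_length, 1], dtype=tf.float32)
mask = broadcast_ones * to_mask[:, tf.newaxis, :]
print(mask.shape)  # (2, 4, 3): every query row repeats the key padding mask.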
def scatter_update(sequence, updates, positions):
  """Scatter-update a sequence.

  Args:
    sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor.
    updates: A tensor of size batch_size*n_positions(*depth).
    positions: A [batch_size, n_positions] tensor.

  Returns:
    updated_sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
      tensor of "sequence" with elements at "positions" replaced by the values
      at "updates". Updates to index 0 are ignored. If there are duplicated
      positions the update is only applied once.
    updates_mask: A [batch_size, seq_len] mask tensor of which inputs were
      updated.
  """
  shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
  depth_dimension = (len(shape) == 3)
  if depth_dimension:
    batch_size, seq_len, depth = shape
  else:
    batch_size, seq_len = shape
    depth = 1
    sequence = tf.expand_dims(sequence, -1)
  n_positions = tf_utils.get_shape_list(positions)[1]

  shift = tf.expand_dims(seq_len * tf.range(batch_size), -1)
  flat_positions = tf.reshape(positions + shift, [-1, 1])
  flat_updates = tf.reshape(updates, [-1, depth])
  updates = tf.scatter_nd(flat_positions, flat_updates,
                          [batch_size * seq_len, depth])
  updates = tf.reshape(updates, [batch_size, seq_len, depth])

  flat_updates_mask = tf.ones([batch_size * n_positions], tf.int32)
  updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask,
                               [batch_size * seq_len])
  updates_mask = tf.reshape(updates_mask, [batch_size, seq_len])
  not_first_token = tf.concat([
      tf.zeros((batch_size, 1), tf.int32),
      tf.ones((batch_size, seq_len - 1), tf.int32)
  ], -1)
  updates_mask *= not_first_token
  updates_mask_3d = tf.expand_dims(updates_mask, -1)

  # account for duplicate positions
  if sequence.dtype == tf.float32:
    updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
    updates /= tf.maximum(1.0, updates_mask_3d)
  else:
    assert sequence.dtype == tf.int32
    updates = tf.math.floordiv(updates, tf.maximum(1, updates_mask_3d))
  updates_mask = tf.minimum(updates_mask, 1)
  updates_mask_3d = tf.minimum(updates_mask_3d, 1)

  updated_sequence = (((1 - updates_mask_3d) * sequence) +
                      (updates_mask_3d * updates))
  if not depth_dimension:
    updated_sequence = tf.squeeze(updated_sequence, -1)

  return updated_sequence, updates_mask
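# Small sketch (plain TensorFlow only; toy values) of the flattening trick used
# above: per-example positions are shifted by `seq_len * batch_index` so a
# single tf.scatter_nd over the flattened [batch * seq_len] axis places every
# update at once.
import tensorflow as tf

batch_size, seq_len = 2, 5
positions = tf.constant([[1, 3], [0, 4]], dtype=tf.int32)      # [batch, n_positions]
updates = tf.constant([[10, 30], [100, 400]], dtype=tf.int32)  # [batch, n_positions]

shift = tf.expand_dims(seq_len * tf.range(batch_size), -1)     # [[0], [5]]
flat_positions = tf.reshape(positions + shift, [-1, 1])        # [[1], [3], [5], [9]]
scattered = tf.scatter_nd(flat_positions,
                          tf.reshape(updates, [-1]),
                          [batch_size * seq_len])
print(tf.reshape(scattered, [batch_size, seq_len]).numpy())
# [[  0  10   0  30   0]
#  [100   0   0   0 400]]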
def call(self, inputs):
  sources = inputs["inputs"]
  targets = inputs["targets"]
  pos_embed = inputs["pos_embed"]
  mask = inputs["mask"]

  input_shape = tf_utils.get_shape_list(sources)
  source_attention_mask = tf.tile(
      tf.expand_dims(mask, axis=1), [1, input_shape[1], 1])
  memory = self._encoder(
      sources, attention_mask=source_attention_mask, pos_embed=pos_embed)

  target_shape = tf_utils.get_shape_list(targets)
  cross_attention_mask = tf.tile(
      tf.expand_dims(mask, axis=1), [1, target_shape[1], 1])
  target_shape = tf.shape(targets)
  decoded = self._decoder(
      tf.zeros_like(targets),
      memory,
      # TODO(b/199545430): self_attention_mask could be set to None when this
      # bug is resolved. Passing ones for now.
      self_attention_mask=tf.ones(
          (target_shape[0], target_shape[1], target_shape[1])),
      cross_attention_mask=cross_attention_mask,
      return_all_decoder_outputs=True,
      input_pos_embed=targets,
      memory_pos_embed=pos_embed)
  return decoded
def call(self, inputs):
  from_tensor = inputs[0]
  to_mask = inputs[1]
  from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
  to_seq_length = to_shape[1]

  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
      dtype=from_tensor.dtype)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding
  # tokens) so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)

  # Here we broadcast along two dimensions to create the mask.
  mask = broadcast_ones * to_mask

  return mask
def _chunk(hidden_states, window_overlap):
  """Converts into overlapping chunks. Chunk size = 2w, overlap size = w."""
  batch_size, seq_length, hidden_dim = get_shape_list(hidden_states)
  num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1

  # define frame size and frame stride (similar to convolution)
  frame_hop_size = window_overlap * hidden_dim
  frame_size = 2 * frame_hop_size
  hidden_states = tf.reshape(hidden_states,
                             (batch_size, seq_length * hidden_dim))

  # chunk with overlap
  chunked_hidden_states = tf.signal.frame(hidden_states, frame_size,
                                          frame_hop_size)

  if tf.executing_eagerly():
    tf.debugging.assert_equal(
        get_shape_list(chunked_hidden_states),
        [batch_size, num_output_chunks, frame_size],
        message=(
            f"Make sure chunking is correctly applied. Chunked hidden states "
            f"should have output dimension "
            f"{[batch_size, num_output_chunks, frame_size]}, but got "
            f"{get_shape_list(chunked_hidden_states)}."),
    )

  chunked_hidden_states = tf.reshape(
      chunked_hidden_states,
      (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
  )

  return chunked_hidden_states
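# Minimal sketch (plain TensorFlow; toy sizes) of the overlapping-chunk layout
# produced above via tf.signal.frame: with window overlap w, each frame covers
# 2*w positions and consecutive frames share w of them.
import tensorflow as tf

seq_length, hidden_dim, w = 8, 1, 2
x = tf.reshape(tf.range(seq_length * hidden_dim, dtype=tf.float32), [1, -1])
frames = tf.signal.frame(x, frame_length=2 * w * hidden_dim,
                         frame_step=w * hidden_dim)
print(frames.numpy())
# [[[0. 1. 2. 3.]
#   [2. 3. 4. 5.]
#   [4. 5. 6. 7.]]]  -> 2 * (8 // 4) - 1 = 3 overlapping chunks of size 2w.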
def call(self, query: tf.Tensor, key: tf.Tensor):
  """Implements the forward pass.

  Args:
    query: query input tensor shape [batch, query length, hidden size].
    key: key input tensor shape [batch, key length, hidden size].

  Returns:
    A tensor in shape of [batch, heads, query length, key length].
  """
  batch_size, qlen = tf_utils.get_shape_list(query)[:2]
  klen = tf_utils.get_shape_list(key)[1]
  context_position = tf.range(qlen)[:, None]
  memory_position = tf.range(klen)[None, :]
  relative_position = memory_position - context_position
  rp_bucket = _relative_position_bucket(
      relative_position,
      bidirectional=self.bidirectional,
      num_buckets=self.relative_attention_num_buckets,
      max_distance=self.relative_attention_max_distance)
  values = tf.nn.embedding_lookup(self._relative_attention_bias, rp_bucket)
  values = tf.expand_dims(
      tf.transpose(values, [2, 0, 1]),
      axis=0)  # shape (1, num_heads, qlen, klen)
  values = tf.tile(values, [batch_size, 1, 1, 1])
  return values
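# Toy illustration (plain TensorFlow; no bucketing) of the relative-position
# matrix computed above: entry [i, j] is j - i, i.e. how far key j sits from
# query i, which is then bucketed and used to look up a per-head bias.
import tensorflow as tf

qlen, klen = 3, 4
context_position = tf.range(qlen)[:, None]
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position
print(relative_position.numpy())
# [[ 0  1  2  3]
#  [-1  0  1  2]
#  [-2 -1  0  1]]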
def _mask_invalid_locations(input_tensor, window_overlap):
  # create correct upper triangle bool mask
  mask_2d_upper = tf.reverse(
      tf.linalg.band_part(
          tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
      axis=[0],
  )

  # pad to full matrix
  padding = tf.convert_to_tensor(
      [[0, get_shape_list(input_tensor)[1] - window_overlap],
       [0, get_shape_list(input_tensor)[3] - window_overlap - 1]])

  # create lower mask
  mask_2d = tf.pad(mask_2d_upper, padding)

  # combine with upper mask
  mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])

  # broadcast to full matrix
  mask_4d = tf.tile(mask_2d[None, :, None, :],
                    (get_shape_list(input_tensor)[0], 1, 1, 1))

  # inf tensor used for masking
  inf_tensor = -float("inf") * tf.ones_like(input_tensor)

  # mask
  input_tensor = tf.where(
      tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)

  return input_tensor
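# Small illustration (plain TensorFlow; window_overlap = 2) of the reversed
# upper-triangle window mask built above with tf.linalg.band_part, before it is
# padded out to the full matrix.
import tensorflow as tf

w = 2
mask_2d_upper = tf.reverse(
    tf.linalg.band_part(tf.ones((w, w + 1)), -1, 0), axis=[0])
print(mask_2d_upper.numpy())
# [[1. 1. 0.]
#  [1. 0. 0.]]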
def _concat_with_global_key_attn_probs(
    self,
    attn_scores,
    key_vectors,
    query_vectors,
    max_num_global_attn_indices,
    is_index_global_attn_nonzero,
    is_local_index_global_attn_nonzero,
    is_local_index_no_global_attn_nonzero,
):
  batch_size = get_shape_list(key_vectors)[0]

  # select global key vectors
  global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)

  # create only global key vectors
  key_vectors_only_global = tf.scatter_nd(
      is_local_index_global_attn_nonzero,
      global_key_vectors,
      shape=(
          batch_size,
          max_num_global_attn_indices,
          self._num_heads,
          self._key_dim,
      ),
  )

  # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
  attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors,
                                         key_vectors_only_global)

  # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
  attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key,
                                                  (0, 3, 1, 2))
  mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],
               ) + tuple(
                   get_shape_list(attn_probs_from_global_key_trans)[-2:])
  mask = tf.ones(mask_shape) * -10000.0
  mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

  # scatter mask
  attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
      attn_probs_from_global_key_trans,
      is_local_index_no_global_attn_nonzero,
      mask,
  )

  # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
  attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans,
                                            (0, 2, 3, 1))

  # concat to attn_probs
  # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
  attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)

  return attn_scores
def _pad_to_window_size(
    self,
    word_ids,
    mask,
    type_ids,
    word_embeddings,
    pad_token_id,
):
  # padding
  attention_window = max(self._attention_window)

  assert (attention_window % 2 == 0), (
      '`attention_window` should be an even value. '
      f'Given {attention_window}')

  input_shape = (
      get_shape_list(word_ids)
      if word_ids is not None else get_shape_list(word_embeddings))
  batch_size, seq_len = input_shape[:2]
  if seq_len is not None:
    padding_len = (attention_window -
                   seq_len % attention_window) % attention_window
  else:
    padding_len = 0

  paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

  if word_ids is not None:
    word_ids = tf.pad(word_ids, paddings, constant_values=pad_token_id)

  if word_embeddings is not None:

    def pad_embeddings():
      word_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
      word_embeddings_padding = self._embedding_layer(word_ids_padding)
      return tf.concat([word_embeddings, word_embeddings_padding], axis=-2)

    word_embeddings = tf.cond(
        tf.math.greater(padding_len, 0), pad_embeddings,
        lambda: word_embeddings)

  mask = tf.pad(
      mask, paddings,
      constant_values=False)  # no attention on the padding tokens
  token_type_ids = tf.pad(
      type_ids, paddings, constant_values=0)  # pad with token_type_id = 0

  return (
      padding_len,
      word_ids,
      mask,
      token_type_ids,
      word_embeddings,
  )
def _gather_indexes(self, sequence_tensor, positions):
  """Gathers the vectors at the specific positions.

  Args:
    sequence_tensor: Sequence output of shape
      (`batch_size`, `seq_length`, `num_hidden`) where `num_hidden` is number
      of hidden units.
    positions: Positions ids of tokens in batched sequences.

  Returns:
    Sequence tensor of shape (batch_size * num_predictions, num_hidden).
  """
  sequence_shape = tf_utils.get_shape_list(
      sequence_tensor, name='sequence_output_tensor')
  batch_size, seq_length, width = sequence_shape

  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

  return output_tensor
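# Compact sketch (plain TensorFlow; toy shapes) of the gather performed above:
# positions are offset by `batch_index * seq_length` so a single tf.gather over
# the flattened [batch * seq_length, width] tensor picks the requested vectors.
import tensorflow as tf

batch_size, seq_length, width = 2, 4, 3
sequence = tf.reshape(
    tf.range(batch_size * seq_length * width, dtype=tf.float32),
    [batch_size, seq_length, width])
positions = tf.constant([[0, 2], [1, 3]])                    # [batch, num_predictions]
flat_offsets = tf.reshape(tf.range(batch_size) * seq_length, [-1, 1])
flat_positions = tf.reshape(positions + flat_offsets, [-1])  # [0, 2, 5, 7]
gathered = tf.gather(tf.reshape(sequence, [-1, width]), flat_positions)
print(gathered.shape)  # (4, 3): batch_size * num_predictions vectors of size width.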
def call(self, input_embeddings: tf.Tensor,
         input_mask: tf.Tensor) -> Dict[str, tf.Tensor]:
  batch_size, seq_len, embedding_dim = tf_utils.get_shape_list(
      input_embeddings, expected_rank=3)
  example_ids = None
  reduced_batch_size = batch_size // self.pack_sequences
  packed_seq_len = self.pack_sequences * seq_len
  packed_embeddings = tf.reshape(
      input_embeddings, [reduced_batch_size, packed_seq_len, embedding_dim])
  input_mask = tf.reshape(input_mask, [reduced_batch_size, packed_seq_len])
  example_ids = 1 + tf.range(self.pack_sequences)
  # Shape: [batch_size, seq_len, pack_sequences].
  example_ids = tf.tile(example_ids[None, :, None],
                        [reduced_batch_size, 1, seq_len])
  example_ids = tf.reshape(example_ids, [reduced_batch_size, packed_seq_len])
  example_ids = tf.where(
      tf.math.equal(input_mask, 0), tf.zeros_like(example_ids), example_ids)
  packing_mask = _packing_mask(example_ids, example_ids, dtype=tf.bool)
  attention_mask = self_attention_mask.get_mask(
      packed_embeddings, input_mask, dtype=tf.bool)
  combined_attention_mask = tf.cast(
      tf.math.logical_and(attention_mask, packing_mask), tf.float32)
  return dict(
      packed_embeddings=packed_embeddings,
      combined_attention_mask=combined_attention_mask)
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
  """Interface to compute losses. Refer to base_task.Task.build_losses."""
  del labels

  left_logits = model_outputs['left_logits']
  right_logits = model_outputs['right_logits']

  batch_size = tf_utils.get_shape_list(left_logits, name='batch_size')[0]
  ranking_labels = tf.range(batch_size)

  loss = tf_utils.safe_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=ranking_labels, logits=left_logits))

  if self.task_config.model.bidirectional:
    right_rank_loss = tf_utils.safe_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=ranking_labels, logits=right_logits))
    loss += right_rank_loss

  return tf.reduce_mean(loss)
def call(self, inputs):
  """Implements call() for the layer."""
  length = self._length
  if inputs is None and length is None:
    raise ValueError("If inputs is None, `length` must be set in "
                     "RelativePositionEmbedding().")
  if inputs is not None:
    input_shape = tf_utils.get_shape_list(inputs)
    if length is not None and length != input_shape[1]:
      raise ValueError(
          "If inputs is not None, `length` must equal input_shape[1].")
    length = input_shape[1]
  position = tf.cast(tf.range(length), tf.float32)
  num_timescales = self._hidden_size // 2
  min_timescale, max_timescale = self._min_timescale, self._max_timescale
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      (tf.cast(num_timescales, tf.float32) - 1))
  inv_timescales = min_timescale * tf.exp(
      tf.cast(tf.range(num_timescales), tf.float32) *
      -log_timescale_increment)
  scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
      inv_timescales, 0)
  position_embeddings = tf.concat(
      [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
  return position_embeddings
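# Quick sanity check (plain TensorFlow; hidden_size chosen arbitrarily) of the
# sinusoidal construction above: positions scaled by a geometric series of
# inverse timescales, with the sin and cos halves concatenated to hidden_size.
import math
import tensorflow as tf

length, hidden_size, min_ts, max_ts = 6, 8, 1.0, 1.0e4
position = tf.cast(tf.range(length), tf.float32)
num_timescales = hidden_size // 2
log_increment = math.log(max_ts / min_ts) / (num_timescales - 1)
inv_timescales = min_ts * tf.exp(
    tf.cast(tf.range(num_timescales), tf.float32) * -log_increment)
scaled_time = position[:, None] * inv_timescales[None, :]
embeddings = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
print(embeddings.shape)  # (6, 8)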
def call(self, inputs):
  """Implements call() for the layer."""
  input_shape = tf_utils.get_shape_list(inputs)
  flat_input = tf.reshape(inputs, [-1])
  output = tf.gather(self.embeddings, flat_input)
  output = tf.reshape(output, input_shape + [self.embedding_size])
  return output
def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions.

  Args:
    sequence_tensor: Sequence output of `BertModel` layer of shape
      (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
      hidden units of `BertModel` layer.
    positions: Positions ids of tokens in sequence to mask for pretraining,
      with dimension (batch_size, max_predictions_per_seq) where
      `max_predictions_per_seq` is maximum number of tokens to mask out and
      predict per each sequence.

  Returns:
    Masked out sequence tensor of shape
    (batch_size * max_predictions_per_seq, num_hidden).
  """
  sequence_shape = tf_utils.get_shape_list(
      sequence_tensor, name='sequence_output_tensor')
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  flat_offsets = tf.keras.backend.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.keras.backend.reshape(
      sequence_tensor, [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor
def _gather_indexes(self, sequence_tensor, positions):
  """Gathers the vectors at the specific positions.

  Args:
    sequence_tensor: Sequence output of `BertModel` layer of shape
      (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
      hidden units of `BertModel` layer.
    positions: Positions ids of tokens in sequence to mask for pretraining,
      with dimension (batch_size, num_predictions) where `num_predictions` is
      maximum number of tokens to mask out and predict per each sequence.

  Returns:
    Masked out sequence tensor of shape (batch_size * num_predictions,
    num_hidden).
  """
  sequence_shape = tf_utils.get_shape_list(
      sequence_tensor, name='sequence_output_tensor')
  batch_size, seq_length, width = sequence_shape

  # `positions` holds the ids of the masked tokens, shape
  # [batch, num_predictions]. Compute the indices of the masked tokens once
  # the batched sequences are flattened.
  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  # Flatten the input.
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  # Gather the masked tokens: [batch * num_predictions, width].
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

  return output_tensor
def call(self, inputs):
  """Implements call() for the layer."""
  unpacked_inputs = tf_utils.unpack_inputs(inputs)
  word_embeddings = unpacked_inputs[0]
  token_type_ids = unpacked_inputs[1]
  input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = word_embeddings
  if self.use_type_embeddings:
    flat_token_type_ids = tf.reshape(token_type_ids, [-1])
    token_type_embeddings = tf.gather(self.type_embeddings,
                                      flat_token_type_ids)
    token_type_embeddings = tf.reshape(token_type_embeddings,
                                       [batch_size, seq_length, width])
    output += token_type_embeddings

  if self.use_position_embeddings:
    position_embeddings = tf.expand_dims(
        tf.slice(self.position_embeddings, [0, 0], [seq_length, width]),
        axis=0)
    output += position_embeddings

  output = self.output_layer_norm(output)
  output = self.output_dropout(output)

  return output
def sample_from_softmax(logits, disallow=None):
  """Implement softmax sampling using gumbel softmax trick.

  Args:
    logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
      the generator output logits for each masked position.
    disallow: If `None`, we directly sample tokens from the logits. Otherwise,
      this is a tensor of size [batch_size, num_token_predictions, vocab_size]
      indicating the true word id in each masked position.

  Returns:
    sampled_tokens: A [batch_size, num_token_predictions, vocab_size] one hot
      tensor indicating the sampled word id in each masked position.
  """
  if disallow is not None:
    logits -= 1000.0 * disallow
  uniform_noise = tf.random.uniform(
      tf_utils.get_shape_list(logits), minval=0, maxval=1)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)

  # Here we essentially follow the original paper and use temperature 1.0 for
  # generator output logits.
  sampled_tokens = tf.one_hot(
      tf.argmax(tf.nn.softmax(logits + gumbel_noise), -1,
                output_type=tf.int32), logits.shape[-1])
  return sampled_tokens
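# Minimal sketch (plain TensorFlow; toy logits) of the Gumbel-max trick used
# above: adding Gumbel noise to logits and taking an argmax is equivalent to
# drawing a sample from the softmax distribution over those logits.
import tensorflow as tf

logits = tf.math.log(tf.constant([[0.1, 0.6, 0.3]]))  # [batch=1, vocab=3]
uniform_noise = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
sampled = tf.argmax(logits + gumbel_noise, axis=-1, output_type=tf.int32)
one_hot = tf.one_hot(sampled, depth=3)
print(sampled.numpy(), one_hot.numpy())  # e.g. [1] and its one-hot row.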
def call(self, input_tensor, unpooled_len=0):
  if self.pool_size == 1:
    return input_tensor

  batch_size, seq_len = tf_utils.get_shape_list(input_tensor, expected_rank=2)
  # reshape tensor in order to use tf.nn.pool
  reshaped_tensor = tf.reshape(input_tensor, [batch_size, seq_len, 1])
  if self.nocls:
    tensor_to_pool = reshaped_tensor[:, 1:, :]
  else:
    tensor_to_pool = reshaped_tensor
  if unpooled_len > 0:
    tensor_to_pool = tensor_to_pool[:, :-unpooled_len, :]

  pooled_tensor = tf.nn.max_pool(
      tensor_to_pool,
      ksize=self.pool_size,
      strides=self.pool_size,
      padding='SAME')

  if self.nocls:
    pooled_tensor = tf.concat([reshaped_tensor[:, 0:1, :], pooled_tensor],
                              axis=1)
  if unpooled_len > 0:
    pooled_tensor = tf.concat(
        [pooled_tensor, reshaped_tensor[:, -unpooled_len:, :]], axis=1)
  pooled_tensor = tf.reshape(pooled_tensor, [batch_size, -1])
  return pooled_tensor
def symbols_to_logits_fn(ids, i, cache):
  """Generate logits for next potential IDs.

  Args:
    ids: Current decoded sequences. int tensor with shape `(batch_size *
      beam_size, i + 1)`.
    i: Loop index.
    cache: Dictionary of values storing the encoder output, encoder-decoder
      attention bias, and previous decoder attention values.

  Returns:
    Tuple of (logits with shape `(batch_size * beam_size, vocab_size)`,
      updated cache values).
  """
  # Set decoder input to the last generated IDs
  decoder_input = ids[:, -1:]

  # Preprocess decoder input by getting embeddings and adding timing signal.
  # decoder_input = self.embedding_softmax_layer(decoder_input)
  source_decoder_input = decoder_input
  decoder_input = self.embedding_lookup(decoder_input)
  embedding_mask = tf.cast(
      tf.not_equal(source_decoder_input, 0), decoder_input.dtype)
  decoder_input *= tf.expand_dims(embedding_mask, -1)
  decoder_input += timing_signal[i]
  if self._padded_decode:
    # indexing does not work on TPU.
    bias_shape = decoder_self_attention_mask.shape.as_list()
    self_attention_mask = tf.slice(
        decoder_self_attention_mask, [0, i, 0],
        [bias_shape[0], 1, bias_shape[2]])
  else:
    self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]

  decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
  batch_size = decoder_shape[0]
  decoder_length = decoder_shape[1]

  self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
  attention_mask = cache.get("encoder_decoder_attention_mask")
  attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

  decoder_outputs = self.decoder_layer(
      decoder_input,
      cache.get("encoder_outputs"),
      self_attention_mask=self_attention_mask,
      cross_attention_mask=attention_mask,
      cache=cache,
      decode_loop_step=i if self._padded_decode else None)

  decoder_outputs = tf.cast(decoder_outputs, dtype=self.compute_dtype)
  logits = self._embedding_linear(self.embedding_lookup.embeddings,
                                  decoder_outputs)
  logits = tf.squeeze(logits, axis=[1])
  return logits, cache
def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
  """Implement softmax sampling using gumbel softmax trick to select k items.

  Args:
    logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
      the generator output logits for each masked position.
    k: Number of samples.
    disallow: If `None`, we directly sample tokens from the logits. Otherwise,
      this is a tensor of size [batch_size, num_token_predictions, vocab_size]
      indicating the true word id in each masked position.
    use_topk: Whether to use tf.nn.top_k or an iterative approach, where the
      latter is empirically faster.

  Returns:
    sampled_tokens: A [batch_size, num_token_predictions, k] tensor indicating
      the sampled word id in each masked position.
  """
  if use_topk:
    if disallow is not None:
      logits -= 10000.0 * disallow
    uniform_noise = tf.random.uniform(
        tf_utils.get_shape_list(logits), minval=0, maxval=1)
    gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
    _, sampled_tokens = tf.nn.top_k(logits + gumbel_noise, k=k, sorted=False)
  else:
    sampled_tokens_list = []
    vocab_size = tf_utils.get_shape_list(logits)[-1]
    if disallow is not None:
      logits -= 10000.0 * disallow

    uniform_noise = tf.random.uniform(
        tf_utils.get_shape_list(logits), minval=0, maxval=1)
    gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
    logits += gumbel_noise

    for _ in range(k):
      token_ids = tf.argmax(logits, -1, output_type=tf.int32)
      sampled_tokens_list.append(token_ids)
      logits -= 10000.0 * tf.one_hot(
          token_ids, depth=vocab_size, dtype=tf.float32)
    sampled_tokens = tf.stack(sampled_tokens_list, -1)
  return sampled_tokens
def call(self, inputs):
  """Implements call() for the layer."""
  input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
  if self._use_dynamic_slicing:
    position_embeddings = self._position_embeddings[:input_shape[1], :]
  else:
    position_embeddings = self._position_embeddings
  return tf.broadcast_to(position_embeddings, input_shape)
def remove_sos_from_seq(seq, pad_token_id):
  """Remove the start sequence token while keeping seq length."""
  batch_size, seq_len = tf_utils.get_shape_list(seq, expected_rank=2)
  # remove <s>
  targets = seq[:, 1:]
  # pad
  pad_ids = tf.ones([batch_size], tf.int32) * pad_token_id
  targets = tf.concat([targets, tf.expand_dims(pad_ids, axis=1)], axis=1)
  tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
  return targets
def call(self, inputs):
  """Implements call() for the layer."""
  input_shape = tf_utils.get_shape_list(inputs)
  # Flatten the [batch, seq_len] input so the lookup is a single gather.
  flat_input = tf.reshape(inputs, [-1])
  output = tf.gather(self.embeddings, flat_input)
  # Restore the batched shape.
  output = tf.reshape(output, input_shape + [self.embedding_size])
  return output
def call(self, inputs, length=None):
  """Implements call() for the layer.

  Args:
    inputs: A tensor whose second dimension will be used as `length`. If
      `None`, the other `length` argument must be specified.
    length: An optional integer specifying the number of positions. If both
      `inputs` and `length` are specified, `length` must be equal to the
      second dimension of `inputs`.

  Returns:
    A tensor in shape of [length, hidden_size].
  """
  if inputs is None and length is None:
    raise ValueError("If inputs is None, `length` must be set in "
                     "RelativePositionEmbedding().")
  if inputs is not None:
    input_shape = tf_utils.get_shape_list(inputs)
    if length is not None and length != input_shape[1]:
      raise ValueError(
          "If inputs is not None, `length` must equal input_shape[1].")
    length = input_shape[1]
  # e.g. range(length) -> [0, 1, ..., length - 1]
  position = tf.cast(tf.range(length), tf.float32)
  # e.g. 8 // 2 = 4
  num_timescales = self._hidden_size // 2
  # e.g. 1.0, 1.0e4
  min_timescale, max_timescale = self._min_timescale, self._max_timescale
  # e.g. log(1.e4 / 1.) / (4 - 1) = 3.07
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      (tf.cast(num_timescales, tf.float32) - 1))
  # e.g. 1.0 * exp([0.0, 1.0, 2.0, 3.0] * -3.07)
  inv_timescales = min_timescale * tf.exp(
      tf.cast(tf.range(num_timescales), tf.float32) *
      -log_timescale_increment)
  # (length, 1) * (1, num_timescales)
  scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
      inv_timescales, 0)
  # Apply sin and cos separately, then concatenate into a hidden_size-long
  # vector:
  #   r = log(max_ / min_) / (hidden_size / 2 - 1)
  #   a = [[0], [1], ..., [length - 1]] * exp(-r * [0, 1, ..., hidden_size/2 - 1])
  #   o = concat(sin(a), cos(a))  --> [length, hidden_size]
  position_embeddings = tf.concat(
      [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
  return position_embeddings
def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
  """Pads rows and then flips rows and columns."""
  hidden_states_padded = tf.pad(
      hidden_states_padded,
      paddings)  # padding value is not important because it will be overwritten
  batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(
      hidden_states_padded)
  hidden_states_padded = tf.reshape(
      hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
  return hidden_states_padded
def call(self, input_positions):
  """Implements call() for the layer."""
  batch_size, seq_len = tf_utils.get_shape_list(
      input_positions, expected_rank=2)
  flat_positions = tf.reshape(input_positions, [-1])
  position_embeddings = tf.gather(self._position_embeddings, flat_positions)
  position_embeddings = tf.reshape(position_embeddings,
                                   [batch_size, seq_len, self.embed_dim])
  if self._use_dynamic_slicing:
    position_embeddings = position_embeddings[:, :seq_len, :]
  return position_embeddings
def call(self, target_embedding):
  lm_data = self.dense(target_embedding)
  lm_data = self.layer_norm(lm_data)
  lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
  logits = tf.nn.bias_add(lm_data, self.bias)

  masked_positions_shape = tf_utils.get_shape_list(
      target_embedding, name='masked_positions_tensor')
  logits = tf.reshape(logits,
                      [-1, masked_positions_shape[1], self._vocab_size])
  if self._output_type == 'logits':
    return logits
  return tf.nn.log_softmax(logits)
def _parse_inputs(self, inputs):
  """Parses the `call` inputs and returns a uniform output."""
  sources = inputs.get("inputs", None)
  input_mask = inputs.get("input_masks", None)
  embedded = inputs.get("embedded_inputs", None)

  if sources is None and embedded is not None:
    embedded_inputs = embedded
    boolean_mask = input_mask
    input_shape = tf_utils.get_shape_list(embedded, expected_rank=3)
    source_dtype = embedded.dtype
  elif sources is not None:
    embedded_inputs = self.embedding_lookup(sources)
    boolean_mask = tf.not_equal(sources, 0)
    input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
    source_dtype = sources.dtype
  else:
    raise KeyError(
        "The call method expects either `inputs` or `embedded_inputs` and "
        "`input_masks` as input features.")

  return embedded_inputs, boolean_mask, input_shape, source_dtype
def call(self, inputs):
  """Implements call() for the layer."""
  if self._use_dynamic_slicing:
    input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
    seq_length = input_shape[1]
    width = input_shape[2]
    position_embeddings = tf.expand_dims(
        tf.slice(self._position_embeddings, [0, 0], [seq_length, width]),
        axis=0)
  else:
    position_embeddings = tf.expand_dims(self._position_embeddings, axis=0)
  return position_embeddings
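# Toy sketch (plain TensorFlow; a random table stands in for the learned
# weights) of the dynamic slice above: a [max_length, width] table is cut down
# to the current sequence length and given a leading axis so it broadcasts over
# the batch when added to the word embeddings.
import tensorflow as tf

max_length, width, seq_length = 10, 4, 6
table = tf.random.normal([max_length, width])
position_embeddings = tf.expand_dims(
    tf.slice(table, [0, 0], [seq_length, width]), axis=0)
print(position_embeddings.shape)  # (1, 6, 4)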