def test_shift_elements_right_2d(self):
  tensor = tf.constant([
      [1, 2, 3, 4],  #
      [5, 6, 7, 8],  #
      [9, 10, 11, 12],  #
  ])

  self.assertAllEqual(
      [
          [0, 1, 2, 3],  #
          [0, 5, 6, 7],  #
          [0, 9, 10, 11],  #
      ],
      tensor_utils.shift_elements_right(tensor))

  self.assertAllEqual(
      [
          [3, 4, -1, -1],  #
          [7, 8, -1, -1],  #
          [11, 12, -1, -1],  #
      ],
      tensor_utils.shift_elements_right(tensor, amount=-2, pad_value=-1))

  self.assertAllEqual(
      [
          [0, 0, 0, 0],  #
          [0, 0, 0, 0],  #
          [1, 2, 3, 4],  #
      ],
      tensor_utils.shift_elements_right(tensor, axis=-2, amount=2))

  self.assertAllEqual(tensor,
                      tensor_utils.shift_elements_right(tensor, amount=0))
def test_shift_elements_right_1d(self):
  tensor = tf.constant([5, 4, 3, 2, 1])

  self.assertAllEqual([0, 5, 4, 3, 2],
                      tensor_utils.shift_elements_right(tensor))
  self.assertAllEqual([0, 0, 0, 5, 4],
                      tensor_utils.shift_elements_right(tensor, amount=3))
  self.assertAllEqual([3, 2, 1, 0, 0],
                      tensor_utils.shift_elements_right(tensor, amount=-2))
  self.assertAllEqual(
      [3, 2, 1, -1, -1],
      tensor_utils.shift_elements_right(tensor, amount=-2, pad_value=-1))
  self.assertAllEqual([0, 0, 0, 0, 0],
                      tensor_utils.shift_elements_right(tensor, amount=10))
  self.assertAllEqual(tensor,
                      tensor_utils.shift_elements_right(tensor, amount=0))

  with self.assertRaises(ValueError):
    tensor_utils.shift_elements_right(tensor, axis=1)
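# The tests above pin down the `shift_elements_right` contract: shift by
# `amount` positions along `axis` (negative shifts left), fill vacated
# positions with `pad_value`, and raise ValueError for an out-of-range axis.
# Below is a minimal pad-and-slice sketch of that contract, assuming static
# shapes; it is illustrative only, and the actual implementation in
# `tensor_utils` may differ.
def _shift_elements_right_sketch(tensor: tf.Tensor,
                                 axis: int = -1,
                                 amount: int = 1,
                                 pad_value: int = 0) -> tf.Tensor:
  """Pads `tensor` along `axis`, then slices back to the original length."""
  tensor = tf.convert_to_tensor(tensor)
  rank = tensor.shape.rank
  if not -rank <= axis < rank:
    raise ValueError('`axis` is out of range for `tensor`.')
  axis %= rank
  length = tensor.shape.as_list()[axis]
  # Shifting by more than the axis length just yields all padding.
  shift = min(abs(amount), length)

  # Pad on the left for a right shift, on the right for a left shift.
  paddings = [[0, 0]] * rank
  paddings[axis] = [shift, 0] if amount >= 0 else [0, shift]
  padded = tf.pad(tensor, paddings, constant_values=pad_value)

  # Slice the original length back out, skipping the leading `shift`
  # elements when shifting left.
  begin = [0] * rank
  if amount < 0:
    begin[axis] = shift
  size = [-1] * rank
  size[axis] = length
  return tf.slice(padded, begin, size)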
def make_att_mask_from_breakpoints(att_breakpoints: tf.Tensor,
                                   use_starting_breakpoints: bool = False,
                                   name: Optional[Text] = None) -> tf.Tensor:
  """Makes self-attention mask from attention breakpoints.

  Each attention breakpoint marks the end of a segment by default (or the
  start if `use_starting_breakpoints` is True), and the resulting mask
  prevents attention across different segments.

  Args:
    att_breakpoints: <int32>[batch_size, seq_len] Tensor containing only 0 and
      1 values, where each "1" marks the end of a segment (or the start,
      depending on `use_starting_breakpoints`).
    use_starting_breakpoints: If True, breakpoints represent starts of
      segments rather than ends of segments. Default False.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, seq_len] attention mask.
  """
  with tf.name_scope(name or 'make_att_mask_from_breakpoints'):
    att_breakpoints = tf.convert_to_tensor(att_breakpoints)

    if att_breakpoints.shape.rank != 2:
      raise ValueError('`att_breakpoints` must be a 2-D tensor.')

    if not use_starting_breakpoints:
      att_breakpoints = tensor_utils.shift_elements_right(
          att_breakpoints, axis=-1, amount=1)

    segment_ids = tf.cumsum(att_breakpoints, axis=1)
    return make_segmented_att_mask(segment_ids)
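# Worked example for `make_att_mask_from_breakpoints` (values are
# illustrative): ending breakpoints [[0, 1, 0, 0, 1]] mark two segments,
# positions {0, 1} and {2, 3, 4}. Shifting right by one yields starting
# breakpoints [[0, 0, 1, 0, 0]], whose cumulative sum gives
# segment_ids = [[0, 0, 1, 1, 1]].
#
#   mask = make_att_mask_from_breakpoints(tf.constant([[0, 1, 0, 0, 1]]))
#
# `mask[0]` is then block-diagonal: a 2x2 block of ones over positions
# {0, 1}, a 3x3 block of ones over positions {2, 3, 4}, and zeros elsewhere,
# so attention never crosses a segment boundary.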
def make_local_att_mask_from_breakpoints(
    att_breakpoints: tf.Tensor,
    local_radius: int,
    use_starting_breakpoints: bool = False,
    name: Optional[Text] = None) -> tf.Tensor:
  """Makes local self-attention mask from attention breakpoints.

  Each attention breakpoint marks the end of a segment by default (or the
  start if `use_starting_breakpoints` is True), and the resulting mask
  prevents attention across different segments. The result can be used as
  `l2l_att_mask` in `layers.GlobalLocalTransformerLayers` for example.

  Args:
    att_breakpoints: <int32>[batch_size, seq_len] Tensor containing only 0 and
      1 values, where each "1" marks the end of a segment (or the start,
      depending on `use_starting_breakpoints`).
    local_radius: The local radius as expected by
      `layers.GlobalLocalTransformerLayers`. Must be positive.
    use_starting_breakpoints: If True, breakpoints represent starts of
      segments rather than ends of segments. Default False.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask.
  """
  with tf.name_scope(name or 'make_local_att_mask_from_breakpoints'):
    att_breakpoints = tf.convert_to_tensor(att_breakpoints)

    if att_breakpoints.shape.rank != 2:
      raise ValueError('`att_breakpoints` must be a 2-D tensor.')

    if not use_starting_breakpoints:
      att_breakpoints = tensor_utils.shift_elements_right(
          att_breakpoints, axis=-1, amount=1)

    # [batch_size, seq_len]
    segment_ids = tf.cumsum(att_breakpoints, axis=1)

    return make_local_segmented_att_mask(segment_ids, local_radius)
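# Continuing the illustrative breakpoints [[0, 1, 0, 0, 1]] above with
# local_radius=1: each position's row in the [batch_size, seq_len, 3] result
# covers the window [position - 1, position + 1]. With
# segment_ids = [0, 0, 1, 1, 1], position 2 starts a new segment, so its row
# is [0, 1, 1]: it cannot attend back to position 1, but can attend to itself
# and to position 3. Window entries that fall outside the sequence are
# likewise 0.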
def _build_model(model_config, features, is_training, flags):
  """Build an ETC model for OpenKP."""
  global_embedding_adder = None
  long_embedding_adder = None

  # Create `global_embedding_adder` if using visual features.
  if flags.use_visual_features_in_global or flags.use_visual_features_in_long:
    global_embedding_adder = _create_global_visual_feature_embeddings(
        model_config, features, flags)

  if flags.use_visual_features_in_long:
    # Create `long_embedding_adder` based on `global_embedding_adder`.
    long_embedding_adder = gather_global_embeddings_to_long(
        global_embedding_adder, features['long_vdom_idx'])

  if not flags.use_visual_features_in_global:
    global_embedding_adder = None

  model = modeling.EtcModel(
      config=model_config,
      is_training=is_training,
      use_one_hot_relative_embeddings=flags.use_tpu)

  model_inputs = dict(
      token_ids=features['long_token_ids'],
      global_token_ids=features['global_token_ids'],
      long_embedding_adder=long_embedding_adder,
      global_embedding_adder=global_embedding_adder)
  for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
    model_inputs[field.name] = features[field.name]

  long_output, _ = model(**model_inputs)

  word_embeddings_unnormalized = batch_segment_sum_embeddings(
      long_embeddings=long_output,
      long_word_idx=features['long_word_idx'],
      long_input_mask=features['long_input_mask'])
  word_emb_layer_norm = tf.keras.layers.LayerNormalization(
      axis=-1, epsilon=1e-12, name='word_emb_layer_norm')
  word_embeddings = word_emb_layer_norm(word_embeddings_unnormalized)

  ngram_logit_list = []
  for i in range(flags.kp_max_length):
    conv = tf.keras.layers.Conv1D(
        filters=model_config.hidden_size,
        kernel_size=i + 1,
        padding='valid',
        activation=tensor_utils.get_activation('gelu'),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=0.02 / math.sqrt(i + 1)),
        name=f'{i + 1}gram_conv')
    layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name=f'{i + 1}gram_layer_norm')
    logit_dense = tf.keras.layers.Dense(
        units=1,
        activation=None,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        name=f'logit_dense{i}')

    # [batch_size, long_max_length - i]
    unpadded_logits = tf.squeeze(
        logit_dense(layer_norm(conv(word_embeddings))), axis=-1)

    # Pad to the right to get back to `long_max_length`.
    padded_logits = tf.pad(unpadded_logits, paddings=[[0, 0], [0, i]])

    # Padding logits should be ignored, so we make a large negative mask adder
    # for them. A position starts a valid (i+1)-gram only if it and the next
    # `i` positions are all real words, which is exactly the word input mask
    # shifted left by `i`.
    shifted_word_mask = tf.cast(
        tensor_utils.shift_elements_right(
            features['long_word_input_mask'], axis=-1, amount=-i),
        dtype=padded_logits.dtype)
    mask_adder = -10000.0 * (1.0 - shifted_word_mask)
    ngram_logit_list.append(padded_logits * shifted_word_mask + mask_adder)

  # [batch_size, kp_max_length, long_max_length]
  ngram_logits = tf.stack(ngram_logit_list, axis=1)

  extra_model_losses = model.losses

  return ngram_logits, extra_model_losses
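# Illustration of the shifted-mask trick above (hypothetical values): with
# long_word_input_mask = [1, 1, 1, 0, 0] and i = 1 (bigram logits), shifting
# left by 1 gives [1, 1, 0, 0, 0]. Only positions 0 and 1 start a bigram whose
# words are both real; position 2's bigram would include padding word 3, so
# its logit is zeroed and replaced by the -10000.0 mask adder.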
def test_shift_elements_right_3d(self):
  tensor = tf.constant([
      [
          [1, -1],  #
          [2, -2],  #
          [3, -3],  #
      ],  #
      [
          [4, -4],  #
          [5, -5],  #
          [6, -6],  #
      ],  #
  ])

  self.assertAllEqual(
      [
          [
              [-1, 0],  #
              [-2, 0],  #
              [-3, 0],  #
          ],  #
          [
              [-4, 0],  #
              [-5, 0],  #
              [-6, 0],  #
          ],  #
      ],
      tensor_utils.shift_elements_right(tensor, amount=-1))

  self.assertAllEqual(
      [
          [
              [-1, -1],  #
              [-1, -1],  #
              [-1, -1],  #
          ],  #
          [
              [1, -1],  #
              [2, -2],  #
              [3, -3],  #
          ],  #
      ],
      tensor_utils.shift_elements_right(tensor, axis=0, pad_value=-1))

  self.assertAllEqual(
      [
          [
              [0, 0],  #
              [0, 0],  #
              [1, -1],  #
          ],  #
          [
              [0, 0],  #
              [0, 0],  #
              [4, -4],  #
          ],  #
      ],
      tensor_utils.shift_elements_right(tensor, axis=1, amount=2))

  self.assertAllEqual(
      tensor, tensor_utils.shift_elements_right(tensor, axis=2, amount=0))