def build_model(etc_model_config: modeling.EtcConfig,
                features: Dict[str, tf.Tensor], flat_sequence: bool,
                is_training: bool, answer_encoding_method: str, use_tpu: bool,
                use_wordpiece: bool):
  """Builds the ETC HotpotQA model together with its prediction heads.

  Args:
    etc_model_config: ETC configuration. It is handed to `modeling.EtcModel`
      and its `local_radius`, `relative_pos_max_distance`, `use_hard_g2l_mask`
      and `use_hard_l2g_mask` attributes are used to build the side inputs.
    features: Dict of input tensors. The keys read are `long_token_ids`,
      `long_sentence_ids`, `long_paragraph_ids`, `long_paragraph_breakpoints`,
      `long_token_type_ids`, `global_token_ids`,
      `global_paragraph_breakpoints` and `global_token_type_ids`.
    flat_sequence: Forwarded unchanged to
      `qa_input_utils.make_global_local_transformer_side_inputs`.
    is_training: Forwarded unchanged to `modeling.EtcModel`.
    answer_encoding_method: `"span"` selects the begin/end span head. Any
      other value selects the BIO tagging head.
    use_tpu: Forwarded to `modeling.EtcModel` as
      `use_one_hot_relative_embeddings`.
    use_wordpiece: If True the CLS global token id is taken from the wordpiece
      defaults, otherwise from the sentencepiece defaults.

  Returns:
    A 4-tuple `(supporting_facts_logits, answer_logits, answer_type_logits,
    extra_model_losses)` where:
      supporting_facts_logits: [batch_size, global_seq_length] Tensor.
      answer_logits: For `"span"`, a pair `(answer_begin_logits,
        answer_end_logits)`, each [batch_size, long_seq_length]. Otherwise a
        single [batch_size, long_seq_length, 3] Tensor of BIO logits.
      answer_type_logits: [batch_size, 3] Tensor computed from the first
        global token.
      extra_model_losses: The value of `model.losses`.
  """
  long_token_ids = features["long_token_ids"]
  long_sentence_ids = features["long_sentence_ids"]
  long_paragraph_ids = features["long_paragraph_ids"]
  long_paragraph_breakpoints = features["long_paragraph_breakpoints"]
  long_token_type_ids = features["long_token_type_ids"]
  global_token_ids = features["global_token_ids"]
  global_paragraph_breakpoints = features["global_paragraph_breakpoints"]
  global_token_type_ids = features["global_token_type_ids"]

  model = modeling.EtcModel(
      config=etc_model_config,
      is_training=is_training,
      use_one_hot_relative_embeddings=use_tpu)

  model_inputs = dict(
      token_ids=long_token_ids,
      global_token_ids=global_token_ids,
      segment_ids=long_token_type_ids,
      global_segment_ids=global_token_type_ids)

  default_global_token_ids = (
      generate_tf_examples_lib.WORDPIECE_DEFAULT_GLOBAL_TOKEN_IDS
      if use_wordpiece else
      generate_tf_examples_lib.SENTENCEPIECE_DEFAULT_GLOBAL_TOKEN_IDS)
  cls_token_id = default_global_token_ids["CLS_TOKEN_ID"]

  # Global positions holding the CLS token are flagged via
  # `ignore_hard_g2l_mask`; the exact effect is defined by the helper.
  ignore_hard_g2l_mask = tf.cast(
      tf.equal(global_token_ids, cls_token_id), dtype=long_sentence_ids.dtype)
  side_inputs = qa_input_utils.make_global_local_transformer_side_inputs(
      long_paragraph_breakpoints=long_paragraph_breakpoints,
      long_paragraph_ids=long_paragraph_ids,
      long_sentence_ids=long_sentence_ids,
      global_paragraph_breakpoints=global_paragraph_breakpoints,
      local_radius=etc_model_config.local_radius,
      relative_pos_max_distance=etc_model_config.relative_pos_max_distance,
      use_hard_g2l_mask=etc_model_config.use_hard_g2l_mask,
      ignore_hard_g2l_mask=ignore_hard_g2l_mask,
      flat_sequence=flat_sequence,
      use_hard_l2g_mask=etc_model_config.use_hard_l2g_mask)
  model_inputs.update(side_inputs.to_dict(exclude_none_values=True))

  long_output, global_output = model(**model_inputs)

  batch_size, long_seq_length, long_hidden_size = tensor_utils.get_shape_list(
      long_output, expected_rank=3)
  _, global_seq_length, global_hidden_size = tensor_utils.get_shape_list(
      global_output, expected_rank=3)

  # Flatten batch and sequence dims so each head is a single matmul.
  long_output_matrix = tf.reshape(
      long_output, [batch_size * long_seq_length, long_hidden_size])
  global_output_matrix = tf.reshape(
      global_output, [batch_size * global_seq_length, global_hidden_size])

  # Get the logits for the supporting facts predictions (one per global
  # token).
  supporting_facts_output_weights = tf.get_variable(
      "supporting_facts_output_weights", [1, global_hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  supporting_facts_output_bias = tf.get_variable(
      "supporting_facts_output_bias", [1], initializer=tf.zeros_initializer())
  supporting_facts_logits = tf.matmul(
      global_output_matrix, supporting_facts_output_weights, transpose_b=True)
  supporting_facts_logits = tf.nn.bias_add(supporting_facts_logits,
                                           supporting_facts_output_bias)
  supporting_facts_logits = tf.reshape(supporting_facts_logits,
                                       [batch_size, global_seq_length])

  # Get the logits for the answer type prediction, read off the first global
  # token.
  num_answer_types = 3  # SPAN, YES, NO
  answer_type_output_weights = tf.get_variable(
      "answer_type_output_weights", [num_answer_types, global_hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  answer_type_output_bias = tf.get_variable(
      "answer_type_output_bias", [num_answer_types],
      initializer=tf.zeros_initializer())
  answer_type_logits = tf.matmul(
      global_output[:, 0, :], answer_type_output_weights, transpose_b=True)
  answer_type_logits = tf.nn.bias_add(answer_type_logits,
                                      answer_type_output_bias)

  extra_model_losses = model.losses

  if answer_encoding_method == "span":
    # Get the logits for the begin and end indices.
    num_span_outputs = 2  # begin, end
    answer_span_output_weights = tf.get_variable(
        "answer_span_output_weights", [num_span_outputs, long_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    answer_span_output_bias = tf.get_variable(
        "answer_span_output_bias", [num_span_outputs],
        initializer=tf.zeros_initializer())
    answer_span_logits = tf.matmul(
        long_output_matrix, answer_span_output_weights, transpose_b=True)
    answer_span_logits = tf.nn.bias_add(answer_span_logits,
                                        answer_span_output_bias)
    answer_span_logits = tf.reshape(
        answer_span_logits, [batch_size, long_seq_length, num_span_outputs])
    # Move the begin/end axis to the front so it can be unstacked.
    answer_span_logits = tf.transpose(answer_span_logits, [2, 0, 1])
    answer_begin_logits, answer_end_logits = tf.unstack(
        answer_span_logits, axis=0)

    return (supporting_facts_logits, (answer_begin_logits, answer_end_logits),
            answer_type_logits, extra_model_losses)

  # Get the logits for the answer BIO encodings.
  num_bio_tags = 3
  answer_bio_output_weights = tf.get_variable(
      "answer_bio_output_weights", [num_bio_tags, long_hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  # Bound to its own local name so it no longer shadows the answer-type bias.
  answer_bio_output_bias = tf.get_variable(
      "answer_bio_output_bias", [num_bio_tags],
      initializer=tf.zeros_initializer())
  answer_bio_logits = tf.matmul(
      long_output_matrix, answer_bio_output_weights, transpose_b=True)
  answer_bio_logits = tf.nn.bias_add(answer_bio_logits, answer_bio_output_bias)
  answer_bio_logits = tf.reshape(answer_bio_logits,
                                 [batch_size, long_seq_length, num_bio_tags])

  return (supporting_facts_logits, answer_bio_logits, answer_type_logits,
          extra_model_losses)
def build_model(etc_model_config, features, is_training, flags):
  """Builds an ETC model with per-token position heads and an answer-type head.

  Args:
    etc_model_config: Configuration passed to `modeling.EtcModel`.
    features: Dict of input tensors. `token_ids` and `global_token_ids` are
      required. Any additional key whose name matches a field of
      `input_utils.GlobalLocalTransformerSideInputs` is forwarded to the model.
    is_training: Forwarded unchanged to `modeling.EtcModel`.
    flags: Object from which `use_tpu`, `mask_long_output`, `sep_tok_id` and
      `cls_tok_id` are read.

  Returns:
    A 3-tuple `(position_logits, answer_type_logits, extra_model_losses)`:
      position_logits: List of 4 Tensors, each [batch_size, l_seq_length].
        NOTE(review): presumably short/long answer start/end logits; the
        order of the four outputs is not established in this function.
      answer_type_logits: [batch_size, 5] Tensor computed from the first long
        token.
      extra_model_losses: The value of `model.losses`.
  """
  token_ids = features["token_ids"]
  global_token_ids = features["global_token_ids"]

  model = modeling.EtcModel(
      config=etc_model_config,
      is_training=is_training,
      use_one_hot_relative_embeddings=flags.use_tpu)

  model_inputs = dict(token_ids=token_ids, global_token_ids=global_token_ids)
  # Forward whichever side inputs were precomputed into `features`.
  for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
    if field.name in features:
      model_inputs[field.name] = features[field.name]

  # Get the logits for the start and end predictions.
  l_final_hidden, _ = model(**model_inputs)

  batch_size, l_seq_length, hidden_size = tensor_utils.get_shape_list(
      l_final_hidden, expected_rank=3)

  num_answer_types = 5  # NULL, YES, NO, LONG, SHORT
  num_long_outputs = 4  # Number of per-token logits emitted by the dense layer.

  # We add a dense layer to the long output:
  l_output_weights = tf.get_variable(
      "cls/nq/long_output_weights", [num_long_outputs, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  l_output_bias = tf.get_variable(
      "cls/nq/long_output_bias", [num_long_outputs],
      initializer=tf.zeros_initializer())

  l_final_hidden_matrix = tf.reshape(
      l_final_hidden, [batch_size * l_seq_length, hidden_size])
  l_logits = tf.matmul(l_final_hidden_matrix, l_output_weights,
                       transpose_b=True)
  l_logits = tf.nn.bias_add(l_logits, l_output_bias)
  l_logits = tf.reshape(l_logits, [batch_size, l_seq_length, num_long_outputs])

  if flags.mask_long_output:
    # Mask out invalid SA/LA start/end positions:
    # 1) find the SEP and CLS tokens:
    long_sep = tf.cast(tf.equal(token_ids, flags.sep_tok_id), tf.int32)
    long_not_sep = 1 - long_sep
    long_cls = tf.cast(tf.equal(token_ids, flags.cls_tok_id), tf.int32)
    # 2) accum sum the SEPs, and the only possible answers are those with sum
    #    equal to 1 (except SEPs) and the CLS position
    l_mask = tf.cast(tf.equal(tf.cumsum(long_sep, axis=-1), 1), tf.int32)
    l_mask = 1 - ((l_mask * long_not_sep) + long_cls)
    # 3) apply the mask to the logits: invalid positions get a large negative
    #    offset that is broadcast over all output channels.
    l_mask = tf.expand_dims(tf.cast(l_mask, tf.float32) * -10E8, 2)
    l_logits = tf.math.add(l_logits, l_mask)

  # Get the logits for the answer type prediction.
  answer_type_output_layer = l_final_hidden[:, 0, :]
  # The last dimension of the slice equals `hidden_size`, which is already
  # known from `get_shape_list`; this avoids the TF1-only `Dimension.value`.
  answer_type_output_weights = tf.get_variable(
      "answer_type_output_weights", [num_answer_types, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  answer_type_output_bias = tf.get_variable(
      "answer_type_output_bias", [num_answer_types],
      initializer=tf.zeros_initializer())
  answer_type_logits = tf.matmul(
      answer_type_output_layer, answer_type_output_weights, transpose_b=True)
  answer_type_logits = tf.nn.bias_add(answer_type_logits,
                                      answer_type_output_bias)

  extra_model_losses = model.losses

  # Move the output-channel axis to the front and split it into one tensor
  # per channel.
  l_logits = tf.transpose(l_logits, [2, 0, 1])
  l_unstacked_logits = tf.unstack(l_logits, axis=0)
  return (list(l_unstacked_logits), answer_type_logits, extra_model_losses)
def process_model_output(model_config,
                         mode,
                         global_output_tensor,
                         global_token_type_ids_tensor,
                         labels,
                         is_real_example,
                         add_final_layer=True,
                         label_smoothing=0.0):
  """Processes model output embeddings and computes loss and logits.

  Args:
    model_config: Config object from which `intermediate_size`, `hidden_act`,
      `hidden_dropout_prob`, `hidden_size` and `initializer_range` are read.
    mode: A `tf.estimator.ModeKeys` value. Dropout in the optional final layer
      is only active for `TRAIN`.
    global_output_tensor: Rank-3 Tensor [batch_size, global_seq_len,
      hidden_size] of global token embeddings.
    global_token_type_ids_tensor: Tensor compared element-wise against
      `multihop_utils.CANDIDATE_GLOBAL_TOKEN_TYPE_ID` to find candidate
      positions. Assumed to be [batch_size, global_seq_len] so it broadcasts
      with the logits -- TODO confirm against caller.
    labels: Tensor of target global-token indices, one-hot encoded with depth
      `global_seq_len`.
    is_real_example: Per-example weights multiplied into the loss.
      Presumably 1.0 for real examples and 0.0 for padding -- verify against
      caller.
    add_final_layer: If True, a `wrappers.ResidualBlock` is applied to the
      embeddings before the output projection.
    label_smoothing: Amount of probability mass spread uniformly over the
      candidate positions. `0.0` disables smoothing.

  Returns:
    A 3-tuple `(loss, per_example_loss, logits)` where `loss` is the weighted
    mean of `per_example_loss` and `logits` is [batch_size, global_seq_len]
    with non-candidate positions set to -10000.
  """
  batch_size, global_seq_len, hidden_size = tensor_utils.get_shape_list(
      global_output_tensor, expected_rank=3)
  global_output_tensor = tf.reshape(global_output_tensor,
                                    [batch_size * global_seq_len, hidden_size])

  if add_final_layer:
    with tf.variable_scope("global_output_layer/transform"):
      is_training = mode == tf.estimator.ModeKeys.TRAIN
      final_layer = wrappers.ResidualBlock(
          inner_intermediate_size=model_config.intermediate_size,
          inner_activation=tensor_utils.get_activation(
              model_config.hidden_act),
          use_pre_activation_order=False,
          dropout_probability=model_config.hidden_dropout_prob)
      global_output_tensor = final_layer(
          global_output_tensor, training=is_training)

  output_weights = tf.get_variable(
      "output_weights", [1, model_config.hidden_size],
      initializer=tf.truncated_normal_initializer(
          stddev=model_config.initializer_range))
  output_bias = tf.get_variable(
      "output_bias", [1], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):
    logits = tf.matmul(global_output_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    tf.logging.info("*** logits initial are {} *** ".format(logits))
    logits = tf.reshape(logits, [batch_size, global_seq_len])
    tf.logging.info("*** logits after reshape are {} *** ".format(logits))

    # Consider only candidate global tokens in the global output.
    multiplier_mask = tf.cast(
        tf.equal(global_token_type_ids_tensor,
                 multihop_utils.CANDIDATE_GLOBAL_TOKEN_TYPE_ID),
        dtype=logits.dtype)
    # Non-candidate positions are zeroed and then pushed to -10000 so they
    # receive negligible probability after the softmax.
    adder_mask = -10000.0 * (1.0 - multiplier_mask)
    logits = (logits * multiplier_mask + adder_mask)
    tf.logging.info("*** global_token_type_ids_tensor is {} *** ".format(
        global_token_type_ids_tensor))
    tf.logging.info("*** adder_mask is {} *** ".format(adder_mask))
    tf.logging.info(
        "*** multiplier_mask is {} *** ".format(multiplier_mask))
    tf.logging.info("*** logits computed are {} *** ".format(logits))

    log_probs = tf.nn.log_softmax(logits, axis=-1)
    one_hot_labels = tf.one_hot(
        labels, depth=global_seq_len, dtype=tf.float32)
    if label_smoothing > 0:
      num_classes = tf.reduce_sum(multiplier_mask, axis=-1)
      # An example without any candidate token would otherwise divide by
      # zero: `inf * 0` below yields NaN, which then poisons the batch loss
      # even when that example is weighted out by `is_real_example`.
      num_classes = tf.maximum(num_classes, 1.0)
      num_classes = tf.expand_dims(num_classes, -1)
      one_hot_labels = (1 - label_smoothing) * one_hot_labels
      one_hot_labels += (label_smoothing / num_classes)
      # Ensure smoothing of labels only for applicable global (candidate)
      # tokens.
      one_hot_labels *= multiplier_mask

    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    numerator = tf.reduce_sum(per_example_loss * is_real_example)
    # The small epsilon keeps the mean finite when no example is real.
    denominator = tf.reduce_sum(is_real_example) + 1e-5
    loss = numerator / denominator

    return (loss, per_example_loss, logits)
def make_fixed_block_side_inputs(
    input_mask: tf.Tensor,
    num_tokens_per_block: int,
    local_radius: int,
    relative_pos_max_distance: int,
    use_hard_g2l_mask: bool = False,
    use_hard_l2g_mask: bool = False,
    global_token_id: int = 1,
    name: Optional[Text] = None
) -> Tuple[GlobalLocalTransformerSideInputs, tf.Tensor]:
  """Creates side inputs that follow a "fixed blocks" pattern.

  Only a BERT-style `input_mask` is needed: the long input is cut into
  consecutive blocks of `num_tokens_per_block` tokens and one global token is
  generated per block, so no global features have to be supplied. The "fixed
  blocks" experiments for NQ and OpenKP were implemented through example
  generation instead; this function is included to illustrate how the side
  inputs can be derived from the mask alone.

  Args:
    input_mask: <int32>[batch_size, long_seq_len] Tensor holding 1 for actual
      tokens and 0 for padding (same format as original BERT). `long_seq_len`
      must be statically known.
    num_tokens_per_block: Positive number of long tokens assigned to each
      global token. The original BERT pre-training data (also used for ETC
      pre-training) implied a value of about 27, but 16 or 32 would be
      reasonable as well.
    local_radius: Number of tokens to the left/right that a long token may
      locally attend to. E.g. 1 lets a token attend to one neighbour on each
      side.
    relative_pos_max_distance: Maximum distance for relative position
      representations; larger distances are clipped to this value. 0 skips
      relative position representations entirely.
    use_hard_g2l_mask: If True, a global token attends only to the long tokens
      of its own block. If False (the default setup), it attends to all
      non-padding long tokens.
    use_hard_l2g_mask: If True, a long token attends only to the global token
      of its block. If False (the default setup), it attends to all
      non-padding global tokens.
    global_token_id: Id assigned to global tokens. Defaults to `1`, the value
      used during ETC pre-training.
    name: A name for the operation (optional).

  Returns:
    A `(side_inputs, global_token_ids)` tuple:
      side_inputs: `GlobalLocalTransformerSideInputs` with all side input
        tensors.
      global_token_ids: <int32>[batch_size, global_seq_len] Tensor of global
        token ids suitable for `EtcModel`. Every global token carries
        `global_token_id`, except padding tokens.

  Raises:
    ValueError: If `num_tokens_per_block` is not positive or `long_seq_len`
      is not statically known.
  """
  if num_tokens_per_block <= 0:
    raise ValueError('`num_tokens_per_block` must be positive.')

  with tf.name_scope(name or 'make_fixed_block_side_inputs'):
    input_mask = tf.convert_to_tensor(input_mask)

    batch_size = tensor_utils.get_shape_list(input_mask)[0]
    long_seq_len = input_mask.shape.as_list()[1]
    if long_seq_len is None:
      raise ValueError('`long_seq_len` must be statically known.')

    # Ceiling division: a trailing partial block still gets a global token.
    global_seq_len = -(-long_seq_len // num_tokens_per_block)

    # [batch_size, global_seq_len, num_tokens_per_block]
    mask_blocks = tensor_utils.split_into_blocks(
        input_mask, block_len=num_tokens_per_block, axis=-1)
    assert mask_blocks.shape.as_list()[1] == global_seq_len

    # [batch_size, global_seq_len]
    # A block counts as non-padding when any of its tokens is real.
    block_has_token = tf.reduce_max(mask_blocks, axis=-1)
    global_input_mask = tf.minimum(block_has_token, 1)

    # [long_seq_len]
    # Each long position is labelled with the index of the block it is in.
    block_indices = tf.range(global_seq_len, dtype=tf.int32)
    repeated_block_indices = tf.repeat(block_indices, num_tokens_per_block)
    sentence_ids = repeated_block_indices[:long_seq_len]

    # [batch_size, long_seq_len]
    sentence_ids = tf.broadcast_to(sentence_ids, [batch_size, long_seq_len])

    side_inputs = make_global_local_transformer_side_inputs_from_example_ids(
        long_example_ids=input_mask,
        global_example_ids=global_input_mask,
        sentence_ids=sentence_ids,
        local_radius=local_radius,
        relative_pos_max_distance=relative_pos_max_distance,
        use_hard_g2l_mask=use_hard_g2l_mask,
        use_hard_l2g_mask=use_hard_l2g_mask)

    # Padding blocks have mask 0 and therefore keep id 0.
    global_token_ids = global_token_id * global_input_mask

    return side_inputs, global_token_ids
def make_global_local_transformer_side_inputs_from_example_ids(
    long_example_ids: tf.Tensor,
    global_example_ids: tf.Tensor,
    sentence_ids: tf.Tensor,
    local_radius: int,
    relative_pos_max_distance: int,
    use_hard_g2l_mask: bool = False,
    use_hard_l2g_mask: bool = False,
    name: Optional[Text] = None) -> GlobalLocalTransformerSideInputs:
  """Builds side input tensors from example ids and sentence ids.

  With packed examples (e.g. for pre-training), every example needs its own
  unique id in `long_example_ids`/`global_example_ids`, and padding needs a
  unique id different from all example ids. Without packing there are just
  two ids: one for example tokens and one for padding; the classic BERT
  `input_mask` is a valid special case of `long_example_ids` then.

  The remaining arguments are interpreted as in
  `make_global_local_transformer_side_inputs`.

  Args:
    long_example_ids: <int32>[batch_size, long_seq_len] Tensor of example ids
      of different packed examples.
    global_example_ids: <int32>[batch_size, global_seq_len] Tensor of example
      ids of different packed examples.
    sentence_ids: <int32>[batch_size, long_seq_len] Tensor of ids indicating
      which sentence each token belongs to. For this dataset, "sentence"
      refers to real natural language sentence, not a BERT "sentence" from
      the "next sentence prediction" task.
    local_radius: Number of tokens to the left/right that a long token may
      locally attend to. E.g. 1 lets a token attend to one neighbour on each
      side.
    relative_pos_max_distance: Maximum distance for relative position
      representations; larger distances are clipped to this value. 0 skips
      relative position representations entirely.
    use_hard_g2l_mask: If True, global tokens only attend to tokens of the
      corresponding sentences in the long input. If False, global tokens
      attend to all sentences within the corresponding global example.
    use_hard_l2g_mask: If True, long tokens only attend to the corresponding
      global tokens. If False, long tokens attend to all the global tokens
      within the corresponding global example.
    name: A name for the operation (optional).

  Returns:
    A `GlobalLocalTransformerSideInputs` with all relevant tensors set. The
    four `*_relative_att_ids` fields are None when
    `relative_pos_max_distance` is 0.
  """
  with tf.name_scope(name or 'make_global_local_transformer_side_inputs'):
    long_example_ids = tf.convert_to_tensor(long_example_ids)
    global_example_ids = tf.convert_to_tensor(global_example_ids)
    sentence_ids = tf.convert_to_tensor(sentence_ids)

    long_seq_len = tensor_utils.get_shape_list(long_example_ids)[1]
    global_seq_len = tensor_utils.get_shape_list(global_example_ids)[1]

    def _sentence_matches_global_position():
      """Returns int32 [batch, long, global]: 1 where sentence id == index."""
      global_positions = tf.range(global_seq_len, dtype=sentence_ids.dtype)
      is_match = tf.equal(sentence_ids[:, :, tf.newaxis],
                          global_positions[tf.newaxis, tf.newaxis, :])
      return tf.cast(is_match, tf.int32)

    l2l_att_mask = feature_utils.make_local_segmented_att_mask(
        long_example_ids, local_radius)
    g2g_att_mask = feature_utils.make_segmented_att_mask(global_example_ids)

    # Long and global tokens may attend to each other only within the same
    # example.
    same_example = tf.equal(long_example_ids[:, :, tf.newaxis],
                            global_example_ids[:, tf.newaxis, :])
    l2g_att_mask = tf.cast(same_example, tf.int32)
    g2l_att_mask = tf.transpose(l2g_att_mask, perm=[0, 2, 1])

    if use_hard_g2l_mask:
      # Restrict every global token to its own sentence rather than all the
      # sentences of its global example.
      global_positions = tf.range(global_seq_len, dtype=sentence_ids.dtype)
      owns_long_token = tf.equal(global_positions[tf.newaxis, :, tf.newaxis],
                                 sentence_ids[:, tf.newaxis, :])
      g2l_att_mask *= tf.cast(owns_long_token, tf.int32)

    if use_hard_l2g_mask:
      # Restrict every long token to its own global token rather than all
      # the global tokens of its global example.
      l2g_att_mask *= _sentence_matches_global_position()

    batch_size = tf.shape(long_example_ids)[0]

    relative_att_ids = dict(
        l2l_relative_att_ids=None,
        g2g_relative_att_ids=None,
        l2g_relative_att_ids=None,
        g2l_relative_att_ids=None)

    if relative_pos_max_distance > 0:
      relative_pos_generator = feature_utils.RelativePositionGenerator(
          relative_pos_max_distance)
      l2l_ids = relative_pos_generator.make_local_relative_att_ids(
          seq_len=long_seq_len,
          local_radius=local_radius,
          batch_size=batch_size)
      g2g_ids = relative_pos_generator.make_relative_att_ids(
          seq_len=global_seq_len, batch_size=batch_size)

      # Cross-attention uses a binary id: 1 if the long token belongs to the
      # sentence of the global token, else 0.
      l2g_ids = _sentence_matches_global_position()
      g2l_ids = tf.transpose(l2g_ids, perm=[0, 2, 1])

      # For fused attention, l2l and l2g share the same relative vocabulary,
      # as do g2g and g2l, so we add an offset for l2g and g2l so their
      # original 0/1 ids don't collide with l2l and g2g relative position ids.
      vocab_offset = relative_pos_generator.relative_vocab_size
      l2g_ids += vocab_offset
      g2l_ids += vocab_offset

      relative_att_ids.update(
          l2l_relative_att_ids=l2l_ids,
          g2g_relative_att_ids=g2g_ids,
          l2g_relative_att_ids=l2g_ids,
          g2l_relative_att_ids=g2l_ids)

    return GlobalLocalTransformerSideInputs(
        l2l_att_mask=l2l_att_mask,
        g2g_att_mask=g2g_att_mask,
        l2g_att_mask=l2g_att_mask,
        g2l_att_mask=g2l_att_mask,
        **relative_att_ids)
def make_local_segmented_att_mask(segment_ids: tf.Tensor,
                                  local_radius: int,
                                  name: Optional[Text] = None) -> tf.Tensor:
  """Makes local attention mask preventing attention across different segments.

  Restricts local self-attention to attend within segments, such that tokens
  can only attend to local tokens from the same segment id. The tokens in a
  segment do not need to be contiguous, but attention is still constrained by
  `local_radius`. The output can be used as `l2l_att_mask` in
  `layers.GlobalLocalTransformerLayers` for example.

  Args:
    segment_ids: <int32>[batch_size, seq_len] Tensor of segment ids, all of
      which must be non-negative.
    local_radius: The local radius as expected by
      `layers.GlobalLocalTransformerLayers`. Must be positive.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask.

  Raises:
    ValueError: If `local_radius` is not positive or `segment_ids` is not a
      2-D tensor.
  """
  # Enforce the documented precondition up front; `local_radius` is used as
  # the block length below, so a non-positive value cannot be blocked.
  if local_radius < 1:
    raise ValueError('`local_radius` must be positive.')

  with tf.name_scope(name or 'make_local_segmented_att_mask'):
    segment_ids = tf.convert_to_tensor(segment_ids)

    if segment_ids.shape.rank != 2:
      raise ValueError('`segment_ids` must be a 2-D tensor.')

    batch_size, seq_len = tensor_utils.get_shape_list(segment_ids)

    # Add 1 so that segment id `0` doesn't coincide with `0` padding values
    # introduced later by `tensor_utils.concat_3_blocks()` for example.
    segment_ids += 1

    # [batch_size, num_blocks, local_radius]
    blocked_segment_ids = tensor_utils.split_into_blocks(
        segment_ids, block_len=local_radius, axis=1)

    # [batch_size, num_blocks, 3*local_radius]
    concat_blocked_segment_ids = tensor_utils.concat_3_blocks(
        blocked_segment_ids)

    # [batch_size, num_blocks, local_radius, 3*local_radius]
    tiled_segment_ids = tf.tile(
        concat_blocked_segment_ids[:, :, tf.newaxis, :],
        [1, 1, local_radius, 1])

    # [batch_size, num_blocks, local_radius, 2*local_radius + 1]
    blocked_unskewed_segment_ids = tensor_utils.unskew_elements_right(
        tiled_segment_ids, axis=-1)

    # [batch_size, num_blocks * local_radius, 2*local_radius + 1]
    flat_unskewed_segment_ids = tensor_utils.flatten_dims(
        blocked_unskewed_segment_ids, first_dim=1, last_dim=2)

    # [batch_size, seq_len, 2*local_radius + 1]
    # Trim the sequence axis back to `seq_len`.
    unskewed_segment_ids = tf.slice(
        flat_unskewed_segment_ids, begin=[0, 0, 0], size=[-1, seq_len, -1])

    # [batch_size, seq_len, 1]
    # The middle column of each window is the token itself.
    center_token_segment_id = unskewed_segment_ids[:, :, local_radius:(
        local_radius + 1)]

    # [batch_size, seq_len, 2*local_radius + 1]
    # A neighbour is attendable iff it shares the center token's segment id.
    result = tf.cast(
        tf.equal(unskewed_segment_ids, center_token_segment_id), tf.int32)

    # Use `reshape` to set the static shape when known.
    return tf.reshape(result, [batch_size, seq_len, 2 * local_radius + 1])