def build_model(etc_model_config: modeling.EtcConfig,
                features: Dict[str, tf.Tensor], flat_sequence: bool,
                is_training: bool, answer_encoding_method: str, use_tpu: bool,
                use_wordpiece: bool):
  """Build the ETC HotpotQA model."""
  long_token_ids = features["long_token_ids"]
  long_sentence_ids = features["long_sentence_ids"]
  long_paragraph_ids = features["long_paragraph_ids"]
  long_paragraph_breakpoints = features["long_paragraph_breakpoints"]
  long_token_type_ids = features["long_token_type_ids"]
  global_token_ids = features["global_token_ids"]
  global_paragraph_breakpoints = features["global_paragraph_breakpoints"]
  global_token_type_ids = features["global_token_type_ids"]

  model = modeling.EtcModel(
      config=etc_model_config,
      is_training=is_training,
      use_one_hot_relative_embeddings=use_tpu)

  model_inputs = dict(
      token_ids=long_token_ids,
      global_token_ids=global_token_ids,
      segment_ids=long_token_type_ids,
      global_segment_ids=global_token_type_ids)

  cls_token_id = (generate_tf_examples_lib
                  .SENTENCEPIECE_DEFAULT_GLOBAL_TOKEN_IDS["CLS_TOKEN_ID"])
  if use_wordpiece:
    cls_token_id = (generate_tf_examples_lib
                    .WORDPIECE_DEFAULT_GLOBAL_TOKEN_IDS["CLS_TOKEN_ID"])

  model_inputs.update(
      qa_input_utils.make_global_local_transformer_side_inputs(
          long_paragraph_breakpoints=long_paragraph_breakpoints,
          long_paragraph_ids=long_paragraph_ids,
          long_sentence_ids=long_sentence_ids,
          global_paragraph_breakpoints=global_paragraph_breakpoints,
          local_radius=etc_model_config.local_radius,
          relative_pos_max_distance=etc_model_config.relative_pos_max_distance,
          use_hard_g2l_mask=etc_model_config.use_hard_g2l_mask,
          ignore_hard_g2l_mask=tf.cast(
              tf.equal(global_token_ids, cls_token_id),
              dtype=long_sentence_ids.dtype),
          flat_sequence=flat_sequence,
          use_hard_l2g_mask=etc_model_config.use_hard_l2g_mask).to_dict(
              exclude_none_values=True))

  long_output, global_output = model(**model_inputs)

  batch_size, long_seq_length, long_hidden_size = tensor_utils.get_shape_list(
      long_output, expected_rank=3)
  _, global_seq_length, global_hidden_size = tensor_utils.get_shape_list(
      global_output, expected_rank=3)

  long_output_matrix = tf.reshape(
      long_output, [batch_size * long_seq_length, long_hidden_size])
  global_output_matrix = tf.reshape(
      global_output, [batch_size * global_seq_length, global_hidden_size])

  # Get the logits for the supporting facts predictions.
  supporting_facts_output_weights = tf.get_variable(
      "supporting_facts_output_weights", [1, global_hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  supporting_facts_output_bias = tf.get_variable(
      "supporting_facts_output_bias", [1], initializer=tf.zeros_initializer())
  supporting_facts_logits = tf.matmul(
      global_output_matrix, supporting_facts_output_weights, transpose_b=True)
  supporting_facts_logits = tf.nn.bias_add(supporting_facts_logits,
                                           supporting_facts_output_bias)
  supporting_facts_logits = tf.reshape(supporting_facts_logits,
                                       [batch_size, global_seq_length])

  # Get the logits for the answer type prediction.
  num_answer_types = 3  # SPAN, YES, NO
  answer_type_output_weights = tf.get_variable(
      "answer_type_output_weights", [num_answer_types, global_hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  answer_type_output_bias = tf.get_variable(
      "answer_type_output_bias", [num_answer_types],
      initializer=tf.zeros_initializer())
  answer_type_logits = tf.matmul(
      global_output[:, 0, :], answer_type_output_weights, transpose_b=True)
  answer_type_logits = tf.nn.bias_add(answer_type_logits,
                                      answer_type_output_bias)

  extra_model_losses = model.losses

  if answer_encoding_method == "span":
    # Get the logits for the begin and end indices.
    answer_span_output_weights = tf.get_variable(
        "answer_span_output_weights", [2, long_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    answer_span_output_bias = tf.get_variable(
        "answer_span_output_bias", [2], initializer=tf.zeros_initializer())
    answer_span_logits = tf.matmul(
        long_output_matrix, answer_span_output_weights, transpose_b=True)
    answer_span_logits = tf.nn.bias_add(answer_span_logits,
                                        answer_span_output_bias)
    answer_span_logits = tf.reshape(answer_span_logits,
                                    [batch_size, long_seq_length, 2])
    answer_span_logits = tf.transpose(answer_span_logits, [2, 0, 1])
    answer_begin_logits, answer_end_logits = tf.unstack(
        answer_span_logits, axis=0)

    return (supporting_facts_logits, (answer_begin_logits, answer_end_logits),
            answer_type_logits, extra_model_losses)
  else:
    # Get the logits for the answer BIO encodings.
    answer_bio_output_weights = tf.get_variable(
        "answer_bio_output_weights", [3, long_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    answer_bio_output_bias = tf.get_variable(
        "answer_bio_output_bias", [3], initializer=tf.zeros_initializer())
    answer_bio_logits = tf.matmul(
        long_output_matrix, answer_bio_output_weights, transpose_b=True)
    answer_bio_logits = tf.nn.bias_add(answer_bio_logits,
                                       answer_bio_output_bias)
    answer_bio_logits = tf.reshape(answer_bio_logits,
                                   [batch_size, long_seq_length, 3])

    return (supporting_facts_logits, answer_bio_logits, answer_type_logits,
            extra_model_losses)
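# --- Illustrative only (not part of the original source): a minimal sketch of
# how the supporting-facts logits returned by `build_model` above could be
# consumed. `supporting_facts_labels` and `global_token_mask` are hypothetical
# placeholder tensors; the actual loss construction lives elsewhere in the
# training code.
def _sketch_supporting_facts_loss(supporting_facts_logits,
                                  supporting_facts_labels, global_token_mask):
  """Hypothetical masked sigmoid cross-entropy over global tokens."""
  # One binary cross-entropy term per global token, shape
  # [batch_size, global_seq_length].
  per_token_loss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.cast(supporting_facts_labels, tf.float32),
      logits=supporting_facts_logits)
  # Average only over real (unpadded) global tokens.
  mask = tf.cast(global_token_mask, tf.float32)
  return tf.reduce_sum(per_token_loss * mask) / tf.maximum(
      tf.reduce_sum(mask), 1.0)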
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s", name, features[name].shape)

  long_token_ids = features["long_token_ids"]
  long_token_type_ids = features["long_token_type_ids"]
  global_token_ids = features["global_token_ids"]
  global_token_type_ids = features["global_token_type_ids"]

  model_inputs = dict(
      token_ids=long_token_ids,
      global_token_ids=global_token_ids,
      global_segment_ids=global_token_type_ids,
      segment_ids=long_token_type_ids)
  for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
    model_inputs[field.name] = features[field.name]

  labels = tf.cast(features["label_id"], dtype=tf.int32)
  if "is_real_example" in features:
    is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
  else:
    is_real_example = tf.ones(tf.shape(labels), dtype=tf.float32)

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  model = modeling.EtcModel(
      config=model_config,
      is_training=is_training,
      use_one_hot_embeddings=use_one_hot_embeddings,
      use_one_hot_relative_embeddings=use_tpu)

  _, global_output = model(**model_inputs)

  (total_loss, per_example_loss, logits) = process_model_output(
      model_config, mode, global_output, global_token_type_ids, labels,
      is_real_example, add_final_layer, label_smoothing)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = input_utils.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu,
        optimizer, poly_power, start_warmup_step, learning_rate_schedule,
        weight_decay_rate)

    metrics_dict = metric_fn(
        per_example_loss=per_example_loss,
        logits=logits,
        labels=labels,
        is_real_example=is_real_example,
        is_train=True)
    host_inputs = {
        "global_step":
            tf.expand_dims(tf.train.get_or_create_global_step(), 0),
    }
    host_inputs.update({
        metric_name: tf.expand_dims(metric_tensor, 0)
        for metric_name, metric_tensor in metrics_dict.items()
    })
    host_call = (functools.partial(
        record_summary_host_fn,
        metrics_dir=os.path.join(model_dir, "train_metrics"),
        steps_per_summary=50), host_inputs)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        host_call=host_call)
  elif mode == tf.estimator.ModeKeys.EVAL:
    metric_fn_tensors = dict(
        per_example_loss=per_example_loss,
        logits=logits,
        labels=labels,
        is_real_example=is_real_example)
    eval_metrics = (functools.partial(metric_fn, is_train=False),
                    metric_fn_tensors)
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions={
            "logits": logits,
            # Wrap in `tf.identity` to avoid b/130501786.
            "label_ids": tf.identity(labels),
        },
        scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Unexpected mode {} encountered".format(mode))

  return output_spec
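# --- Illustrative only: a minimal sketch (not from the original source) of how
# a `model_fn` like the one above is typically handed to a TPUEstimator. The
# checkpoint cadence and batch-size handling here are hypothetical
# placeholders, not the repo's actual flag wiring.
def _sketch_build_estimator(model_fn, model_dir, use_tpu, batch_size):
  """Hypothetical TPUEstimator wiring for the `model_fn` above."""
  run_config = tf.estimator.tpu.RunConfig(
      model_dir=model_dir,
      save_checkpoints_steps=1000,
      tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=1000))
  # TPUEstimator falls back to ordinary CPU/GPU execution when use_tpu=False.
  return tf.estimator.tpu.TPUEstimator(
      use_tpu=use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size)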
def build_model(etc_model_config, features, is_training, flags):
  """Build an ETC model."""
  token_ids = features["token_ids"]
  global_token_ids = features["global_token_ids"]

  model = modeling.EtcModel(
      config=etc_model_config,
      is_training=is_training,
      use_one_hot_relative_embeddings=flags.use_tpu)

  model_inputs = dict(token_ids=token_ids, global_token_ids=global_token_ids)
  for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
    if field.name in features:
      model_inputs[field.name] = features[field.name]

  # Get the logits for the start and end predictions.
  l_final_hidden, _ = model(**model_inputs)

  l_final_hidden_shape = tensor_utils.get_shape_list(
      l_final_hidden, expected_rank=3)
  batch_size = l_final_hidden_shape[0]
  l_seq_length = l_final_hidden_shape[1]
  hidden_size = l_final_hidden_shape[2]

  num_answer_types = 5  # NULL, YES, NO, LONG, SHORT

  # We add a dense layer to the long output:
  l_output_weights = tf.get_variable(
      "cls/nq/long_output_weights", [4, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  l_output_bias = tf.get_variable(
      "cls/nq/long_output_bias", [4], initializer=tf.zeros_initializer())

  l_final_hidden_matrix = tf.reshape(
      l_final_hidden, [batch_size * l_seq_length, hidden_size])

  l_logits = tf.matmul(l_final_hidden_matrix, l_output_weights,
                       transpose_b=True)
  l_logits = tf.nn.bias_add(l_logits, l_output_bias)
  l_logits = tf.reshape(l_logits, [batch_size, l_seq_length, 4])

  if flags.mask_long_output:
    # Mask out invalid SA/LA start/end positions:
    # 1) find the SEP and CLS tokens:
    long_sep = tf.cast(tf.equal(token_ids, flags.sep_tok_id), tf.int32)
    long_not_sep = 1 - long_sep
    long_cls = tf.cast(tf.equal(token_ids, flags.cls_tok_id), tf.int32)

    # 2) accum sum the SEPs, and the only possible answers are those with sum
    # equal to 1 (except SEPs) and the CLS position
    l_mask = tf.cast(tf.equal(tf.cumsum(long_sep, axis=-1), 1), tf.int32)
    l_mask = 1 - ((l_mask * long_not_sep) + long_cls)

    # 3) apply the mask to the logits
    l_mask = tf.expand_dims(tf.cast(l_mask, tf.float32) * -10E8, 2)
    l_logits = tf.math.add(l_logits, l_mask)

  # Get the logits for the answer type prediction.
  answer_type_output_layer = l_final_hidden[:, 0, :]
  answer_type_hidden_size = answer_type_output_layer.shape[-1].value

  answer_type_output_weights = tf.get_variable(
      "answer_type_output_weights",
      [num_answer_types, answer_type_hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  answer_type_output_bias = tf.get_variable(
      "answer_type_output_bias", [num_answer_types],
      initializer=tf.zeros_initializer())

  answer_type_logits = tf.matmul(
      answer_type_output_layer, answer_type_output_weights, transpose_b=True)
  answer_type_logits = tf.nn.bias_add(answer_type_logits,
                                      answer_type_output_bias)

  extra_model_losses = model.losses

  l_logits = tf.transpose(l_logits, [2, 0, 1])
  l_unstacked_logits = tf.unstack(l_logits, axis=0)

  return ([l_unstacked_logits[i] for i in range(4)], answer_type_logits,
          extra_model_losses)
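# --- Illustrative only: a small, self-contained example (not part of the
# original source) of the cumulative-sum masking trick used above. The token
# ids below assume the common BERT WordPiece values sep_tok_id=102 and
# cls_tok_id=101; in the model they come from `flags`.
def _sketch_long_output_mask(token_ids, sep_tok_id, cls_tok_id):
  """Recomputes the integer `l_mask` from `build_model` for a toy batch."""
  long_sep = tf.cast(tf.equal(token_ids, sep_tok_id), tf.int32)
  long_not_sep = 1 - long_sep
  long_cls = tf.cast(tf.equal(token_ids, cls_tok_id), tf.int32)
  # Positions whose running [SEP] count equals 1 lie between the first and
  # second [SEP] (inclusive of the second one before `long_not_sep` removes it).
  l_mask = tf.cast(tf.equal(tf.cumsum(long_sep, axis=-1), 1), tf.int32)
  # Mask value 1 = invalid position (gets the -10E8 adder), 0 = valid.
  return 1 - ((l_mask * long_not_sep) + long_cls)

# Example: token ids for "[CLS] q1 q2 [SEP] c1 c2 c3 [SEP] pad"
#   _sketch_long_output_mask([[101, 7, 8, 102, 9, 10, 11, 102, 0]], 102, 101)
# evaluates to [[0, 1, 1, 1, 0, 0, 0, 1, 1]]: only the CLS position and the
# tokens strictly between the first and second [SEP] stay unmasked.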
def _build_model(model_config, features, is_training, flags):
  """Build an ETC model for OpenKP."""
  global_embedding_adder = None
  long_embedding_adder = None

  # Create `global_embedding_adder` if using visual features.
  if flags.use_visual_features_in_global or flags.use_visual_features_in_long:
    global_embedding_adder = _create_global_visual_feature_embeddings(
        model_config, features, flags)

    if flags.use_visual_features_in_long:
      # Create `long_embedding_adder` based on `global_embedding_adder`
      long_embedding_adder = gather_global_embeddings_to_long(
          global_embedding_adder, features['long_vdom_idx'])

    if not flags.use_visual_features_in_global:
      global_embedding_adder = None

  model = modeling.EtcModel(
      config=model_config,
      is_training=is_training,
      use_one_hot_relative_embeddings=flags.use_tpu)

  model_inputs = dict(
      token_ids=features['long_token_ids'],
      global_token_ids=features['global_token_ids'],
      long_embedding_adder=long_embedding_adder,
      global_embedding_adder=global_embedding_adder)
  for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
    model_inputs[field.name] = features[field.name]

  long_output, _ = model(**model_inputs)

  word_embeddings_unnormalized = batch_segment_sum_embeddings(
      long_embeddings=long_output,
      long_word_idx=features['long_word_idx'],
      long_input_mask=features['long_input_mask'])
  word_emb_layer_norm = tf.keras.layers.LayerNormalization(
      axis=-1, epsilon=1e-12, name='word_emb_layer_norm')
  word_embeddings = word_emb_layer_norm(word_embeddings_unnormalized)

  ngram_logit_list = []
  for i in range(flags.kp_max_length):
    conv = tf.keras.layers.Conv1D(
        filters=model_config.hidden_size,
        kernel_size=i + 1,
        padding='valid',
        activation=tensor_utils.get_activation('gelu'),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=0.02 / math.sqrt(i + 1)),
        name=f'{i + 1}gram_conv')
    layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name=f'{i + 1}gram_layer_norm')
    logit_dense = tf.keras.layers.Dense(
        units=1,
        activation=None,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        name=f'logit_dense{i}')

    # [batch_size, long_max_length - i]
    unpadded_logits = tf.squeeze(
        logit_dense(layer_norm(conv(word_embeddings))), axis=-1)

    # Pad to the right to get back to `long_max_length`.
    padded_logits = tf.pad(unpadded_logits, paddings=[[0, 0], [0, i]])

    # Padding logits should be ignored, so we make a large negative mask adder
    # for them.
    shifted_word_mask = tf.cast(
        tensor_utils.shift_elements_right(
            features['long_word_input_mask'], axis=-1, amount=-i),
        dtype=padded_logits.dtype)
    mask_adder = -10000.0 * (1.0 - shifted_word_mask)
    ngram_logit_list.append(padded_logits * shifted_word_mask + mask_adder)

  # [batch_size, kp_max_length, long_max_length]
  ngram_logits = tf.stack(ngram_logit_list, axis=1)

  extra_model_losses = model.losses

  return ngram_logits, extra_model_losses
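# --- Illustrative only: a toy, self-contained sketch (not part of the original
# source) of the n-gram shape handling above. A `kernel_size=k`,
# `padding='valid'` Conv1D over `long_max_length` word embeddings yields
# `long_max_length - k + 1` positions, which `tf.pad` right-pads back to
# `long_max_length` so that index j always scores the k-gram starting at word
# j. Filter sizes here are arbitrary.
def _sketch_ngram_logits(word_embeddings, kernel_size):
  """Produces one toy logit per k-gram start position; padded tail is zeros."""
  conv = tf.keras.layers.Conv1D(
      filters=8, kernel_size=kernel_size, padding='valid', activation=None)
  logit_dense = tf.keras.layers.Dense(units=1, use_bias=False)
  # [batch_size, long_max_length - kernel_size + 1]
  unpadded = tf.squeeze(logit_dense(conv(word_embeddings)), axis=-1)
  # Right-pad the kernel_size - 1 missing start positions back to full length.
  return tf.pad(unpadded, paddings=[[0, 0], [0, kernel_size - 1]])

# With `word_embeddings` of shape [batch, 16, hidden] and kernel_size=3, the
# result has shape [batch, 16], matching one entry of `ngram_logit_list` above
# (before the word-mask adder is applied).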