def build_key(self):
    """Encode the input with the lower layers and pool the result into keys.

    Embeds ``self.input_ids`` / ``self.segment_ids``, runs the first
    ``self.layers_before_key_pooling`` layers through ``self.forward_layer``,
    projects the last layer's intermediate activations to
    ``self.key_dimension`` with a dense layer, and hands the result to
    ``self.key_pooling``.

    Side effects: sets ``self.input_shape``, ``self.attention_mask``,
    ``self.last_intermediate_output``, ``self.last_key_layer`` and
    ``self.debug1``; appends each layer's 3-D output to
    ``self.all_layer_outputs``.

    Returns:
      Whatever ``self.key_pooling`` returns for the
      ``[batch_size, seq_length, key_dimension]`` key tensor.

    Raises:
      ValueError: if ``self.layers_before_key_pooling`` is below 1. The key
        projection reads the last layer's intermediate output, which would
        otherwise never be assigned.
    """
    num_layers = self.layers_before_key_pooling
    if num_layers < 1:
        raise ValueError(
            "layers_before_key_pooling must be at least 1, got %r" % (num_layers,))

    with tf.compat.v1.variable_scope("embeddings"):
        input_tensor = self.get_embeddings(self.input_ids, self.segment_ids)
        self.input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

    with tf.compat.v1.variable_scope("encoder"):
        self.attention_mask = bc.create_attention_mask_from_input_mask(
            input_tensor, self.input_mask)
        prev_output = bc.reshape_to_matrix(input_tensor)
        for layer_idx in range(num_layers):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                intermediate_output, prev_output = self.forward_layer(prev_output)
                # Flatten to [batch_size * seq_length, intermediate_size] so the
                # key projection below sees one row per token.
                intermediate_output = tf.reshape(intermediate_output, [
                    self.batch_size * self.seq_length,
                    self.config.intermediate_size
                ])
                final_output = bc.reshape_from_matrix(prev_output, self.input_shape)
                self.all_layer_outputs.append(final_output)

        self.last_intermediate_output = intermediate_output
        self.last_key_layer = prev_output
        with tf.compat.v1.variable_scope("mr_key"):
            key_vectors = bc.dense(self.key_dimension, self.initializer)(intermediate_output)
            self.debug1 = key_vectors
            key_vectors = tf.reshape(
                key_vectors, [self.batch_size, self.seq_length, self.key_dimension])
            key_output = self.key_pooling(key_vectors)
    return key_output
def get_lexical_lookup(self):
    """Return one learned vector that stands in for per-token word embeddings.

    Gets the variable ``base_second`` of shape ``[hidden_size]`` via
    ``tf.compat.v1.get_variable`` (so creation vs. reuse follows the
    enclosing variable scope). It is returned reshaped to
    ``[1, 1, hidden_size]`` so it broadcasts against
    ``[batch, seq_length, hidden_size]`` tensors.
    """
    base_vector = tf.compat.v1.get_variable(
        name="base_second",
        shape=[self.config.hidden_size],
        initializer=bc.create_initializer(self.config.initializer_range))
    return tf.reshape(base_vector, [1, 1, -1])
def build_by_attention(self, key):
    """Score the input by attending over projected key tokens.

    The token embeddings are a single learned vector
    (``self.get_lexical_lookup``) plus token-type and position embeddings.
    ``key`` is projected to ``num_key_tokens * hidden_size`` and reshaped
    into ``num_key_tokens`` extra tokens that are prepended to the
    sequence. The combined sequence runs through
    ``ssdr_config.num_hidden_layers`` encoder layers.

    Side effects: sets ``self.embedding_output``, ``self.input_mask`` (cast
    to int64 and extended with ones for the key tokens), ``self.seq_length``
    (grown by ``num_key_tokens``), ``self.attention_mask``, ``self.scores``
    and ``self.info_output``; appends each layer's 2-D output to
    ``self.all_layer_outputs``.

    Returns:
      ``(scores, info_output)``. ``scores`` is a dense(1) projection of
      position 0. ``info_output`` is either the first ``num_key_tokens``
      positions ("first_tokens") or a max over the sequence axis
      ("max_pooling").

    Raises:
      ValueError: if ``ssdr_config.info_pooling_method`` is not one of the
        two supported methods.
    """
    pooling_method = self.ssdr_config.info_pooling_method
    if pooling_method not in ("first_tokens", "max_pooling"):
        raise ValueError(
            "Unknown info_pooling_method: %r (expected 'first_tokens' or "
            "'max_pooling')" % (pooling_method,))

    hidden_size = self.config.hidden_size
    with tf.compat.v1.variable_scope("embeddings"):
        lexical_tensor = self.get_lexical_lookup()
        self.embedding_output = self.embedding_postprocessor(
            d_input_ids=self.input_ids,
            input_tensor=lexical_tensor,
            use_token_type=True,
            token_type_ids=self.segment_ids,
            token_type_vocab_size=self.config.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=self.config.initializer_range,
            max_position_embeddings=self.config.max_position_embeddings,
            dropout_prob=self.config.hidden_dropout_prob)
        input_tensor = self.embedding_output  # [def_per_batch, seq_length, hidden_size]

    with tf.compat.v1.variable_scope("encoder"):
        num_key_tokens = self.ssdr_config.num_key_tokens
        project_dim = hidden_size * num_key_tokens
        raw_key = bc.dense(project_dim, self.initializer)(key)
        key_tokens = tf.reshape(
            raw_key, [self.batch_size, num_key_tokens, hidden_size])
        # Prepend the key tokens along the sequence axis.
        input_tensor = tf.concat([key_tokens, input_tensor], axis=1)
        input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

        # Key tokens are always attended to, so extend the mask with ones.
        mask_for_key = tf.ones([self.batch_size, num_key_tokens], dtype=tf.int64)
        self.input_mask = tf.cast(self.input_mask, tf.int64)
        self.input_mask = tf.concat([mask_for_key, self.input_mask], axis=1)
        self.seq_length = self.seq_length + num_key_tokens

        self.attention_mask = bc.create_attention_mask_from_input_mask(
            input_tensor, self.input_mask)
        prev_output = bc.reshape_to_matrix(input_tensor)
        for layer_idx in range(self.ssdr_config.num_hidden_layers):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                _, prev_output = self.forward_layer(prev_output)
                self.all_layer_outputs.append(prev_output)

        final_output = bc.reshape_from_matrix(prev_output, input_shape)
        self.scores = bc.dense(1, self.initializer)(final_output[:, 0, :])

    if pooling_method == "first_tokens":
        self.info_output = final_output[:, :num_key_tokens, :]
    else:  # "max_pooling" (validated above)
        self.info_output = tf.reduce_max(final_output, axis=1)
    return self.scores, self.info_output
def sigmoid_all(all_logits, label_ids):
    """Per-segment sigmoid cross-entropy against one label per example.

    Args:
      all_logits: rank-3 logits whose last axis (``num_seg``) holds one
        score per segment. Presumably ``[batch, 1, num_seg]``; TODO confirm
        against the caller.
      label_ids: labels that are expanded on axis 2 and tiled ``num_seg``
        times, so they must be rank-2 (presumably ``[batch, 1]``).

    Returns:
      ``(mean_probs, loss)``. ``mean_probs`` is the sigmoid probability
      averaged over the segment axis. Its values lie in (0, 1) and are NOT
      logits, so threshold at 0.5 rather than 0. ``loss`` is the mean sigmoid
      cross-entropy over all elements.
    """
    _, _, num_seg = get_shape_list(all_logits)
    label_ids_tile = tf.cast(
        tf.tile(tf.expand_dims(label_ids, 2), [1, 1, num_seg]), tf.float32)
    losses = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=all_logits, labels=label_ids_tile)
    loss = tf.reduce_mean(losses)
    probs = tf.nn.sigmoid(all_logits)
    mean_probs = tf.reduce_mean(probs, axis=2)
    return mean_probs, loss
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """Ranking `model_fn` over paired inputs, built on the Sero model.

    The paired features are combined, split into windows with
    ``split_and_append_sep``, and encoded by ``model_class`` under the
    "sero" variable scope. The pooled output (with dropout 0.1 when
    training) is passed to ``apply_loss_modeling``. Returns the spec built
    by ``ranking_estimator_spec``.
    """
    tf_logging.info("model_fn_ranking")
    log_features(features)

    input_ids, input_mask, segment_ids = combine_paired_input_features(features)
    # Not the real batch size: the paired inputs are combined, so this is
    # presumably 2 * the real batch size.
    combined_batch, _ = get_shape_list(input_mask)
    use_context = tf.ones([combined_batch, 1], tf.int32)
    stacked_input_ids, stacked_input_mask, stacked_segment_ids = split_and_append_sep(
        input_ids, input_mask, segment_ids,
        config.total_sequence_length, config.window_size, CLS_ID, EOW_ID)

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    with tf.compat.v1.variable_scope("sero"):
        model = model_class(config, is_training, train_config.use_one_hot_embeddings)
        model.network_stacked(stacked_input_ids, stacked_input_mask,
                              stacked_segment_ids, use_context)

    pooled_output = model.get_pooled_output()
    if is_training:
        pooled_output = dropout(pooled_output, 0.1)
    loss, losses, y_pred = apply_loss_modeling(config.loss, pooled_output, features)

    assignment_fn = get_assignment_map_from_checkpoint_type(
        train_config.checkpoint_type, config.lower_layers)
    scaffold_fn = checkpoint_init(assignment_fn, train_config)

    prediction = dict(
        stacked_input_ids=stacked_input_ids,
        stacked_input_mask=stacked_input_mask,
        stacked_segment_ids=stacked_segment_ids,
    )

    # The factory is called later with the loss tensor; trainable variables
    # and the accumulation setting are read at that time.
    if train_config.gradient_accumulation == 1:
        def optimizer_factory(loss_tensor):
            return create_optimizer_from_config(loss_tensor, train_config)
    else:
        def optimizer_factory(loss_tensor):
            return grad_accumulation.get_accumulated_optimizer_from_config(
                loss_tensor, train_config, tf.compat.v1.trainable_variables(),
                train_config.gradient_accumulation)

    return ranking_estimator_spec(mode, loss, losses, y_pred, scaffold_fn,
                                  optimizer_factory, prediction)
def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch.

    Args:
      sequence_tensor: rank-3 tensor ``[batch_size, seq_length, width]``
        (rank enforced by ``get_shape_list(..., expected_rank=3)``).
      positions: int32 per-example position indices. They are added to a
        ``[batch_size, 1]`` offset tensor, so presumably
        ``[batch_size, num_positions]``.

    Returns:
      A ``[N, width]`` tensor with one row per entry of ``positions``.
    """
    batch_size, seq_length, width = bert_common.get_shape_list(
        sequence_tensor, expected_rank=3)
    # Row offset of each example once the batch is flattened into one matrix.
    row_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_indices = tf.reshape(positions + row_offsets, [-1])
    flat_sequence = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
    return tf.gather(flat_sequence, flat_indices)
def __init__(self,
             config,  # This is different from BERT config
             is_training,
             input_ids,
             input_mask,
             token_type_ids,
             use_one_hot_embeddings,
             features,
             ):
    """Encode the main input against several context windows with one BertModel.

    Each example's main sequence is tiled ``config.max_context`` times and
    concatenated along the sequence axis with each of its context windows.
    The resulting ``batch_size * max_context`` sequences run through a
    single ``BertModel``. Their pooled outputs are projected, the first
    ``config.num_context`` are averaged per example, and the average is
    projected again into ``self.pooled_output``.

    Args:
      config: model config (not a plain BERT config). It must expose
        ``set_attrib``, ``max_context``, ``max_context_length``,
        ``num_context``, ``hidden_size`` and ``initializer_range``. A private
        deep copy is used, so the caller's object is never modified.
      is_training: when False, dropout probabilities are set to 0.0 on the
        private copy.
      input_ids, input_mask, token_type_ids: main-sequence tensors
        (``input_ids`` is rank 2 per the ``get_shape_list`` unpack).
      use_one_hot_embeddings: forwarded to ``BertModel``.
      features: dict with ``context_input_ids``, ``context_input_mask`` and
        ``context_segment_ids``. Each is reshaped to
        ``[-1, max_context, max_context_length]``.
    """
    import copy  # local import keeps this block self-contained

    super(MultiContextEncoder, self).__init__()
    # Work on a private copy: zeroing dropout for eval/predict must not leak
    # into a later (e.g. training) graph built from the same config object.
    config = copy.deepcopy(config)
    self.config = config
    if not is_training:
        config.set_attrib("hidden_dropout_prob", 0.0)
        config.set_attrib("attention_probs_dropout_prob", 0.0)

    def reform_context(context):
        return tf.reshape(context, [-1, config.max_context, config.max_context_length])

    batch_size, _ = get_shape_list(input_ids)

    def combine(main_tensor, context_tensor):
        """Pair every context window with a copy of its example's main sequence."""
        main_tiled = tf.tile(tf.expand_dims(main_tensor, 1), [1, config.max_context, 1])
        contexts = reform_context(context_tensor)
        rep_3d = tf.concat([main_tiled, contexts], 2)
        return tf.reshape(rep_3d, [batch_size * config.max_context, -1])

    context_input_ids = features["context_input_ids"]
    context_input_mask = features["context_input_mask"]
    context_segment_ids = features["context_segment_ids"]
    # Context tokens all get segment id 2. This assumes type_vocab_size > 2
    # (TODO confirm in the config).
    context_segment_ids = tf.ones_like(context_segment_ids, tf.int32) * 2
    self.module = BertModel(config=config,
                            is_training=is_training,
                            input_ids=combine(input_ids, context_input_ids),
                            input_mask=combine(input_mask, context_input_mask),
                            token_type_ids=combine(token_type_ids, context_segment_ids),
                            use_one_hot_embeddings=use_one_hot_embeddings,
                            )
    dense_layer_setup = tf.keras.layers.Dense(
        config.hidden_size,
        activation=tf.keras.activations.tanh,
        kernel_initializer=create_initializer(config.initializer_range))
    h1 = self.module.get_pooled_output()
    h2 = dense_layer_setup(h1)
    h2 = tf.reshape(h2, [batch_size, config.max_context, -1])
    h2 = h2[:, :config.num_context]
    h3 = tf.reduce_mean(h2, axis=1)
    # NOTE(review): the same Dense layer instance (shared weights) is applied
    # a second time here. Kept as-is for checkpoint compatibility; confirm
    # the weight sharing is intended.
    h4 = dense_layer_setup(h3)
    self.pooled_output = h4
def get_dummy_apr_input(input_ids, input_mask, def_per_batch, inner_batch_size,
                        max_loc_length, max_position_embeddings):
    """Build placeholder definition-side inputs plus the a-to-b batch mapping.

    Args:
      input_ids, input_mask: passed to ``pool_location_id`` to derive the
        location ids.
      def_per_batch: number of definition rows (Python int; used as a shape).
      inner_batch_size: number of distinct "a" examples. ``def_per_batch`` is
        assumed to be a multiple of it (TODO confirm). Otherwise
        ``ab_mapping`` will have fewer than ``def_per_batch`` entries.
      max_loc_length: forwarded to ``pool_location_id``.
      max_position_embeddings: width of the dummy id/mask/segment tensors.

    Returns:
      ``(d_input_ids, d_input_mask, d_segment_ids, d_location_ids,
      ab_mapping, ab_mapping_mask)``. The three ``d_*`` inputs are the same
      all-zero int64 tensor of shape
      ``[def_per_batch, max_position_embeddings]``. ``ab_mapping`` has shape
      ``[1, inner_batch_size * n_repeat]`` and repeats each index of
      ``range(inner_batch_size)`` ``n_repeat`` times consecutively.
    """
    common_dummy = tf.zeros([def_per_batch, max_position_embeddings], dtype=tf.int64)
    d_input_ids = common_dummy
    d_input_mask = common_dummy
    d_segment_ids = common_dummy

    d_location_ids = pool_location_id(input_ids, input_mask, max_loc_length)

    # Integer division: exact for ints, unlike int(a / b) via float.
    n_repeat = def_per_batch // inner_batch_size
    seq = tf.expand_dims(tf.range(inner_batch_size), 1)
    ab_mapping = tf.reshape(tf.tile(seq, [1, n_repeat]), [1, -1])
    ab_mapping_mask = create_ab_mapping_mask(inner_batch_size, def_per_batch)
    return d_input_ids, d_input_mask, d_segment_ids, d_location_ids, ab_mapping, ab_mapping_mask
def build(self):
    """Build a synthetic graph: one-hot embedding, then repeated shared Dense.

    Creates a ``[40000, 512]`` embedding table and a Keras ``Input`` of
    ``seq_length`` int32 token ids. It embeds them with a one-hot matmul and
    makes ``n_of_t`` shifted copies (``output + j``). It then runs
    ``num_rounds`` rounds of stack -> shared Dense ("MDense") -> dropout(0.5).
    Nothing is returned; the ops only live in the current graph.
    NOTE(review): looks like a scratch/benchmark graph; confirm whether any
    caller relies on it.
    """
    vocab_size = 40000
    embedding_size = 512
    seq_length = 512
    n_of_t = 10
    num_rounds = 20
    initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.02)
    embedding_table = tf.compat.v1.get_variable(
        name="embedding",
        shape=[vocab_size, embedding_size],
        initializer=initializer)

    # tf.one_hot only accepts integer indices; Keras Input defaults to float32.
    input_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    input_shape = get_shape_list(input_ids)
    flat_input_ids = tf.reshape(input_ids, [-1])
    one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
    output = tf.matmul(one_hot_input_ids, embedding_table)
    output = tf.reshape(output, input_shape + [embedding_size])

    t_list = [output + j for j in range(n_of_t)]

    dense = tf.keras.layers.Dense(embedding_size, kernel_initializer=initializer,
                                  name="MDense")
    for i in range(num_rounds):
        t = tf.stack(t_list, 0)
        with tf.compat.v1.variable_scope("scope_A", reuse=i > 0):
            t = dense(t)
        t = tf.nn.dropout(t, rate=0.5)
        # Slot 0 becomes the sum of slots 1..n_of_t-1 (slot 0 itself is not
        # included); slots 1.. are carried over unchanged.
        t_0 = 0
        for j in range(1, n_of_t):
            t_0 += t[j]
        t_list = [t_0] + [t[j] for j in range(1, n_of_t)]
def build(self, value_out, locations):
    """Run the encoder after overwriting selected embedding rows.

    The embeddings are flattened to a ``[batch * seq, hidden]`` matrix with
    ``bc.reshape_to_matrix``. Then ``tf.tensor_scatter_nd_update`` writes
    ``value_out`` at ``locations``. ``locations`` are therefore indices into
    that flattened matrix; presumably ``[k, 1]`` row indices with
    ``value_out`` of shape ``[k, hidden]`` (TODO confirm).

    Side effects: sets ``self.input_shape`` and ``self.attention_mask`` and
    appends each layer's 3-D output to ``self.all_layer_outputs``.

    Returns:
      ``self.all_layer_outputs``.
    """
    with tf.compat.v1.variable_scope("embeddings"):
        embedded = self.get_embeddings(self.input_ids, self.segment_ids)
        self.input_shape = bc.get_shape_list(embedded, expected_rank=3)

    with tf.compat.v1.variable_scope("encoder"):
        self.attention_mask = bc.create_attention_mask_from_input_mask(
            embedded, self.input_mask)
        hidden_2d = bc.reshape_to_matrix(embedded)
        hidden_2d = tf.tensor_scatter_nd_update(hidden_2d, locations, value_out)
        for layer_no in range(self.config.num_hidden_layers):
            with tf.compat.v1.variable_scope("layer_%d" % layer_no):
                _, hidden_2d = self.forward_layer(hidden_2d)
                layer_3d = bc.reshape_from_matrix(hidden_2d, self.input_shape)
                self.all_layer_outputs.append(layer_3d)
    return self.all_layer_outputs
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Prediction-oriented ranking: windows the input with
    ``split_and_append_sep``, encodes it with ``model_class`` under the
    "sero" scope, and maps the pooled output through
    ``get_prediction_structure``. Variables are initialised from
    ``train_config.init_checkpoint`` using ``assignment_map_v2_to_v2``.
    """
    tf_logging.info("model_fn_sero_ranking_predict")
    log_features(features)

    input_mask = features["input_mask"]
    batch_size, _ = get_shape_list(input_mask)
    use_context = tf.ones([batch_size, 1], tf.int32)
    stacked_input_ids, stacked_input_mask, stacked_segment_ids = split_and_append_sep(
        features["input_ids"], input_mask, features["segment_ids"],
        config.total_sequence_length, config.window_size, CLS_ID, EOW_ID)

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    with tf.compat.v1.variable_scope("sero"):
        model = model_class(config, is_training, train_config.use_one_hot_embeddings)
        model.network_stacked(stacked_input_ids, stacked_input_mask,
                              stacked_segment_ids, use_context)

    logits = get_prediction_structure(config.loss, model.get_pooled_output())

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names, init_fn = get_init_fn(
        tvars, train_config.init_checkpoint, assignment_map.assignment_map_v2_to_v2)
    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)
    return rank_predict_estimator_spec(logits, mode, scaffold_fn)
def train_modeling(self, input_tensor, masked_lm_positions, masked_lm_weights,
                   loss_base, loss_target):
    """Build two per-position classification heads and their losses.

    Gathers ``input_tensor`` at ``masked_lm_positions`` and projects it with
    ``self.layer1``. The result feeds ``self.logit_dense1`` and
    ``self.logit_dense2``. Each head is trained with softmax cross-entropy
    against the distribution that ``loss_to_prob_pair`` derives from
    ``loss_base`` / ``loss_target`` respectively. Per-position losses are
    weighted by ``masked_lm_weights``, summed per example, and averaged.

    Side effects: sets ``logits1/2``, ``loss1/2``, ``per_example_loss1/2``,
    ``prob1/2`` (softmax probability of class 0), ``total_loss`` and
    ``graph_built``.

    Raises:
      RuntimeError: if the graph has already been built on this instance.
    """
    if self.graph_built:
        raise RuntimeError("train_modeling() was already called; graph is already built.")

    batch_size, _, hidden_dims = get_shape_list(input_tensor)
    input_tensor = bc.gather_indexes(input_tensor, masked_lm_positions)
    # gather_indexes returns a flat matrix; restore the per-example axis.
    input_tensor = tf.reshape(input_tensor, [batch_size, -1, hidden_dims])
    with tf.compat.v1.variable_scope("project"):
        hidden = self.layer1(input_tensor)

    def cross_entropy(logits, loss_label):
        gold_prob = loss_to_prob_pair(loss_label)
        logits = tf.reshape(logits, gold_prob.shape)
        per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
            gold_prob, logits, axis=-1, name=None)
        per_example_loss = tf.cast(masked_lm_weights, tf.float32) * per_example_loss
        losses = tf.reduce_sum(per_example_loss, axis=1)
        loss = tf.reduce_mean(losses)
        return loss, per_example_loss

    with tf.compat.v1.variable_scope("cls1"):
        self.logits1 = self.logit_dense1(hidden)
    with tf.compat.v1.variable_scope("cls2"):
        self.logits2 = self.logit_dense2(hidden)

    self.loss1, self.per_example_loss1 = cross_entropy(self.logits1, loss_base)
    self.loss2, self.per_example_loss2 = cross_entropy(self.logits2, loss_target)

    self.prob1 = tf.nn.softmax(self.logits1)[:, :, 0]
    self.prob2 = tf.nn.softmax(self.logits2)[:, :, 0]
    self.total_loss = self.loss1 + self.loss2
    self.graph_built = True
def build(self):
    """Encode the dictionary-side input and return ``(scores, info_output)``.

    Runs under the "dict" variable scope: embeddings, then
    ``ssdr_config.num_hidden_layers`` encoder layers. ``scores`` is a
    dense(1) projection of position 0. ``info_output`` is either the first
    ``num_key_tokens`` positions ("first_tokens") or a max over the sequence
    axis ("max_pooling").

    Side effects: sets ``self.input_mask`` (cast to int64 and extended),
    ``self.seq_length`` (grown by ``num_key_tokens``),
    ``self.attention_mask``, ``self.scores`` and ``self.info_output``;
    appends each layer's 2-D output to ``self.all_layer_outputs``.

    Raises:
      ValueError: if ``ssdr_config.info_pooling_method`` is not one of the
        two supported methods.
    """
    pooling_method = self.ssdr_config.info_pooling_method
    if pooling_method not in ("first_tokens", "max_pooling"):
        raise ValueError(
            "Unknown info_pooling_method: %r (expected 'first_tokens' or "
            "'max_pooling')" % (pooling_method,))

    with tf.compat.v1.variable_scope("dict"):
        with tf.compat.v1.variable_scope("embeddings"):
            input_tensor = self.get_embeddings(self.input_ids, self.segment_ids)

        with tf.compat.v1.variable_scope("encoder"):
            num_key_tokens = self.ssdr_config.num_key_tokens
            input_shape = bc.get_shape_list(input_tensor, expected_rank=3)

            # NOTE(review): the mask gains num_key_tokens leading ones and
            # self.seq_length grows by num_key_tokens, but nothing here
            # prepends key tokens to input_tensor. Unless get_embeddings
            # already returns seq_length + num_key_tokens positions, the mask
            # and the activations disagree in length. Confirm.
            mask_for_key = tf.ones([self.batch_size, num_key_tokens], dtype=tf.int64)
            self.input_mask = tf.cast(self.input_mask, tf.int64)
            self.input_mask = tf.concat([mask_for_key, self.input_mask], axis=1)
            self.seq_length = self.seq_length + num_key_tokens

            self.attention_mask = bc.create_attention_mask_from_input_mask(
                input_tensor, self.input_mask)
            prev_output = bc.reshape_to_matrix(input_tensor)
            for layer_idx in range(self.ssdr_config.num_hidden_layers):
                with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                    _, prev_output = self.forward_layer(prev_output)
                    self.all_layer_outputs.append(prev_output)

            final_output = bc.reshape_from_matrix(prev_output, input_shape)
            self.scores = bc.dense(1, self.initializer)(final_output[:, 0, :])

        if pooling_method == "first_tokens":
            self.info_output = final_output[:, :num_key_tokens, :]
        else:  # "max_pooling" (validated above)
            self.info_output = tf.reduce_max(final_output, axis=1)
    return self.scores, self.info_output
def embedding_postprocessor(
        self,
        d_input_ids,
        input_tensor,
        use_token_type=False,
        token_type_ids=None,
        token_type_vocab_size=16,
        token_type_embedding_name="token_type_embeddings",
        use_position_embeddings=True,
        position_embedding_name="position_embeddings",
        initializer_range=0.02,
        max_position_embeddings=512,
        dropout_prob=0.1):
    """Add token-type and position embeddings, then layer-norm and dropout.

    Variant of BERT's ``embedding_postprocessor``. The
    ``[batch_size, seq_length]`` shape comes from ``d_input_ids`` instead of
    ``input_tensor``, so ``input_tensor`` only has to broadcast against
    ``[batch_size, seq_length, hidden_size]`` (e.g. a single
    ``[1, 1, hidden_size]`` vector).

    Args:
      d_input_ids: rank-2 ids whose shape defines batch size and sequence
        length (only the shape is used).
      input_tensor: base embeddings to add to.
      use_token_type: whether to add token-type embeddings.
      token_type_ids: ids looked up with a one-hot matmul; required when
        ``use_token_type`` is True.
      token_type_vocab_size: rows of the token-type table.
      token_type_embedding_name: variable name of the token-type table.
      use_position_embeddings: whether to add learned position embeddings.
      position_embedding_name: variable name of the position table.
      initializer_range: stddev passed to ``bc.create_initializer``.
      max_position_embeddings: rows of the position table; sequence length
        is asserted to be <= this.
      dropout_prob: dropout rate for ``bc.layer_norm_and_dropout``.

    Returns:
      The summed embeddings after ``bc.layer_norm_and_dropout``.

    Raises:
      ValueError: if ``use_token_type`` is True and ``token_type_ids`` is None.
    """
    input_shape = bc.get_shape_list(d_input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    width = self.config.hidden_size

    output = input_tensor

    if use_token_type:
        if token_type_ids is None:
            raise ValueError("`token_type_ids` must be specified if "
                             "`use_token_type` is True.")
        token_type_table = tf.compat.v1.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, width],
            initializer=bc.create_initializer(initializer_range))
        # This vocab will be small so we always do one-hot here, since it is
        # always faster for a small vocabulary.
        flat_token_type_ids = tf.reshape(token_type_ids, [-1])
        one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
        token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
        token_type_embeddings = tf.reshape(token_type_embeddings,
                                           [batch_size, seq_length, width])
        output += token_type_embeddings

    if use_position_embeddings:
        assert_op = tf.compat.v1.assert_less_equal(seq_length, max_position_embeddings)
        with tf.control_dependencies([assert_op]):
            full_position_embeddings = tf.compat.v1.get_variable(
                name=position_embedding_name,
                shape=[max_position_embeddings, width],
                initializer=bc.create_initializer(initializer_range))
            # The table covers positions [0, max_position_embeddings); the
            # current sequence uses [0, seq_length), so slice the prefix.
            position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                           [seq_length, -1])
            num_dims = len(output.shape.as_list())

            # Only the last two dimensions are relevant (`seq_length` and
            # `width`), so broadcast over the leading dimensions.
            position_broadcast_shape = []
            for _ in range(num_dims - 2):
                position_broadcast_shape.append(1)
            position_broadcast_shape.extend([seq_length, width])
            position_embeddings = tf.reshape(position_embeddings,
                                             position_broadcast_shape)
            output += position_embeddings

    output = bc.layer_norm_and_dropout(output, dropout_prob)
    return output
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Three-way classification on a Sero model. Each input is treated as a
    single segment (a singleton axis is added). The first token of the
    sequence output is pooled with a tanh Dense layer and fed to
    ``Classification(3, ...)``.

    Raises:
      ValueError: if ``modeling`` is not "sero" or "sero_epsilon".
    """
    tf_logging.info("model_fn_sero_classification")
    log_features(features)

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    batch_size, _ = get_shape_list(input_mask)
    use_context = tf.ones([batch_size, 1], tf.int32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if modeling == "sero":
        model_class = SeroDelta
        print("Using SeroDelta")
    elif modeling == "sero_epsilon":
        model_class = SeroEpsilon
        print("Using SeroEpsilon")
    else:
        # An explicit error instead of `assert False`, which is stripped
        # under `python -O` and would leave model_class unbound.
        raise ValueError("Unsupported modeling: {!r}".format(modeling))

    with tf.compat.v1.variable_scope("sero"):
        model = model_class(config, is_training, train_config.use_one_hot_embeddings)
        # Add a singleton segment axis: [batch, seq] -> [batch, 1, seq].
        input_ids = tf.expand_dims(input_ids, 1)
        input_mask = tf.expand_dims(input_mask, 1)
        segment_ids = tf.expand_dims(segment_ids, 1)
        sequence_output = model.network_stacked(input_ids, input_mask,
                                                segment_ids, use_context)

        first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)
        pooled_output = tf.keras.layers.Dense(
            config.hidden_size,
            activation=tf.keras.activations.tanh,
            kernel_initializer=create_initializer(
                config.initializer_range))(first_token_tensor)

    if "bias_loss" in special_flags:
        loss_weighting = reweight_zero
    else:
        loss_weighting = None

    task = Classification(3, features, pooled_output, is_training, loss_weighting)
    loss = task.loss

    tvars = tf.compat.v1.trainable_variables()
    assignment_fn = assignment_map.assignment_map_v2_to_v2
    initialized_variable_names, init_fn = get_init_fn(
        tvars, train_config.init_checkpoint, assignment_fn)
    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_logging.info("Using single lr ")
        train_op = optimization.create_optimizer_from_config(loss, train_config)
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss,
                                       train_op=train_op, scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        # `mode` (not the model object) must be passed here.
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss,
                                       eval_metrics=task.eval_metrics(),
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"input_ids": input_ids, "logits": task.logits}
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss,
                                       predictions=predictions,
                                       scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Masked-LM with dynamic random masking (``random_masking``). For Sero
    modeling the input is split into windows with ``split_and_append_sep``.
    For "bert" it is kept as a single segment. In PREDICT mode the TF seed
    is fixed to 0 so the masking is reproducible.

    When ``prediction_op == "gradient_to_long_context"``, predictions are
    the mean absolute gradients of the masked-LM loss with respect to each
    ``model.upper_module_inputs`` tensor. They are split along axis 1 at
    ``config.window_size`` into main and context parts.

    Raises:
      ValueError: if ``modeling`` is unsupported or has no model class here.
    """
    tf_logging.info("model_fn_sero_lm")
    log_features(features)

    input_ids = features["input_ids"]  # [batch_size, seq_length]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    is_sero_modeling = "sero" in modeling
    if is_sero_modeling:
        use_context = features["use_context"]
    elif modeling == "bert":
        batch_size, _ = get_shape_list(input_mask)
        use_context = tf.ones([batch_size, 1], tf.int32)
    else:
        raise ValueError("Unsupported modeling: {!r}".format(modeling))

    if mode == tf.estimator.ModeKeys.PREDICT:
        tf.random.set_seed(0)
        seed = 0
    else:
        seed = None

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    tf_logging.info("Using masked_input_ids")
    if is_sero_modeling:
        stacked_input_ids, stacked_input_mask, stacked_segment_ids = split_and_append_sep(
            input_ids, input_mask, segment_ids,
            config.total_sequence_length, config.window_size, CLS_ID, EOW_ID)
        input_ids_2d = r3to2(stacked_input_ids)
        input_mask_2d = r3to2(stacked_input_mask)
    else:  # modeling == "bert" (validated above)
        stacked_input_ids, stacked_input_mask, stacked_segment_ids = \
            input_ids, input_mask, segment_ids
        input_ids_2d = stacked_input_ids
        input_mask_2d = stacked_input_mask

    tf_logging.info("Doing dynamic masking (random)")
    # TODO make stacked_input_ids 2D and recover
    masked_input_ids_2d, masked_lm_positions_2d, masked_lm_ids_2d, masked_lm_weights_2d \
        = random_masking(input_ids_2d, input_mask_2d,
                         train_config.max_predictions_per_seq, MASK_ID, seed, [EOW_ID])

    if is_sero_modeling:
        masked_input_ids = tf.reshape(masked_input_ids_2d, stacked_input_ids.shape)
    else:
        # Single-segment input: add the segment axis expected by network_stacked.
        masked_input_ids = tf.expand_dims(masked_input_ids_2d, 1)
        stacked_input_mask = tf.expand_dims(stacked_input_mask, 1)
        stacked_segment_ids = tf.expand_dims(stacked_segment_ids, 1)

    if modeling == "sero":
        model_class = SeroDelta
    elif modeling == "sero_epsilon":
        model_class = SeroEpsilon
    else:
        # NOTE(review): only the two Sero variants have a model class here.
        # Other values (including "bert") would reach `model_class(...)` with
        # the name unbound; fail explicitly instead.
        raise ValueError(
            "No model class for modeling={!r}; expected 'sero' or "
            "'sero_epsilon'".format(modeling))

    with tf.compat.v1.variable_scope("sero"):
        model = model_class(config, is_training, train_config.use_one_hot_embeddings)
        sequence_output_3d = model.network_stacked(masked_input_ids, stacked_input_mask,
                                                   stacked_segment_ids, use_context)

    masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs \
        = get_masked_lm_output(config, sequence_output_3d,
                               model.get_embedding_table(),
                               masked_lm_positions_2d, masked_lm_ids_2d,
                               masked_lm_weights_2d)

    predictions = None
    if prediction_op == "gradient_to_long_context":
        predictions = {}
        for idx, input_tensor in enumerate(model.upper_module_inputs):
            g = tf.abs(tf.gradients(ys=masked_lm_loss, xs=input_tensor)[0])
            main_g = g[:, :config.window_size, :]
            context_g = g[:, config.window_size:, :]
            main_g = tf.reduce_mean(tf.reduce_mean(main_g, axis=2), axis=1)
            context_g = tf.reduce_mean(tf.reduce_mean(context_g, axis=2), axis=1)
            predictions['main_g_{}'.format(idx)] = main_g
            predictions['context_g_{}'.format(idx)] = context_g

    loss = masked_lm_loss  # + bert_task.masked_lm_loss
    tvars = tf.compat.v1.trainable_variables()
    if train_config.init_checkpoint:
        assignment_fn = get_assignment_map_from_checkpoint_type(
            train_config.checkpoint_type, config.lower_layers)
    else:
        assignment_fn = None
    initialized_variable_names, init_fn = get_init_fn(
        tvars, train_config.init_checkpoint, assignment_fn)
    log_var_assignments(tvars, initialized_variable_names)
    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer_from_config(loss, train_config)
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                       training_hooks=[OomReportingHook()],
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        # `mode` (not the model object) must be passed here.
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, eval_metrics=None,
                                       scaffold_fn=scaffold_fn)
    else:
        if predictions is None:
            predictions = {
                "input_ids": input_ids,
                "masked_input_ids": masked_input_ids,
                "masked_lm_ids": masked_lm_ids_2d,
                "masked_lm_example_loss": masked_lm_example_loss,
                "masked_lm_positions": masked_lm_positions_2d,
            }
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss,
                                       predictions=predictions,
                                       scaffold_fn=scaffold_fn)
    return output_spec
def __init__(self,
             config,
             is_training,
             input_ids,
             input_mask=None,
             token_type_ids=None,
             use_one_hot_embeddings=True,
             scope=None):
    """Constructor for BertModel.

    Args:
      config: `BertConfig` instance.
      is_training: bool. True for training model, false for eval model.
        Controls whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
        Defaults to all ones.
      token_type_ids: (optional) int32 Tensor of shape
        [batch_size, seq_length]. Defaults to all zeros.
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings. On the
        TPU, it is much faster if this is True; on the CPU or GPU, it is
        faster if this is False.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    # Private copy, so zeroing dropout below never mutates the caller's config.
    config = copy.deepcopy(config)
    if not is_training:
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0

    input_shape = get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
        input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    if token_type_ids is None:
        token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    # Variable names (and Keras layer auto-names) depend on the order in
    # which the statements below run; keep it stable for checkpoint loading.
    with tf.compat.v1.variable_scope(scope, default_name="bert"):
        with tf.compat.v1.variable_scope("embeddings"):
            # Perform embedding lookup on the word ids.
            (self.embedding_output, self.embedding_table) = embedding_lookup(
                input_ids=input_ids,
                vocab_size=config.vocab_size,
                embedding_size=config.hidden_size,
                initializer_range=config.initializer_range,
                word_embedding_name="word_embeddings",
                use_one_hot_embeddings=use_one_hot_embeddings)

            # Add positional embeddings and token type embeddings, then layer
            # normalize and perform dropout.
            self.embedding_output = embedding_postprocessor(
                input_tensor=self.embedding_output,
                use_token_type=True,
                token_type_ids=token_type_ids,
                token_type_vocab_size=config.type_vocab_size,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=config.initializer_range,
                max_position_embeddings=config.max_position_embeddings,
                dropout_prob=config.hidden_dropout_prob)

        with tf.compat.v1.variable_scope("encoder"):
            # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
            # mask of shape [batch_size, seq_length, seq_length] which is used
            # for the attention scores.
            attention_mask = create_attention_mask_from_input_mask(
                input_ids, input_mask)

            # Run the stacked transformer.
            # `sequence_output` shape = [batch_size, seq_length, hidden_size].
            # This project's transformer_model returns a pair: the list of all
            # layer outputs and an extra `key` tensor, stored as self.key.
            self.all_encoder_layers, key = transformer_model(
                input_tensor=self.embedding_output,
                attention_mask=attention_mask,
                input_mask=input_mask,
                hidden_size=config.hidden_size,
                num_hidden_layers=config.num_hidden_layers,
                num_attention_heads=config.num_attention_heads,
                is_training=is_training,
                #mr_layer=config.mr_layer,
                mr_num_route=config.mr_num_route,
                #mr_key_layer=config.mr_key_layer,
                intermediate_size=config.intermediate_size,
                intermediate_act_fn=get_activation(config.hidden_act),
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=config.attention_probs_dropout_prob,
                initializer_range=config.initializer_range,
                do_return_all_layers=True)
            self.key = key

        self.sequence_output = self.all_encoder_layers[-1]
        # The "pooler" converts the encoded sequence tensor of shape
        # [batch_size, seq_length, hidden_size] to a tensor of shape
        # [batch_size, hidden_size]. This is necessary for segment-level
        # (or segment-pair-level) classification tasks where we need a fixed
        # dimensional representation of the segment.
        with tf.compat.v1.variable_scope("pooler"):
            # We "pool" the model by simply taking the hidden state corresponding
            # to the first token. We assume that this has been pre-trained
            first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
            self.pooled_output = tf.keras.layers.Dense(config.hidden_size,
                                                       activation=tf.keras.activations.tanh,
                                                       kernel_initializer=create_initializer(config.initializer_range))(first_token_tensor)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Dual-encoder scoring. The query and the ``num_docs`` documents per
    example are encoded by separate ``model_class`` instances under
    ``dual_model_prefix1`` / ``dual_model_prefix2``. Scores are the dot
    products of the pooled query vector with each pooled document vector
    (``all_logits``: ``[batch, 1, num_docs]``). Those scores are reduced by
    ``hinge_all``, ``sigmoid_all`` or ``hinge_max`` (selected by
    ``special_flags``) into ``(logits, loss)``.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    q_input_ids = features["q_input_ids"]
    q_input_mask = features["q_input_mask"]
    d_input_ids = features["d_input_ids"]
    d_input_mask = features["d_input_mask"]

    input_shape = get_shape_list(q_input_ids, expected_rank=2)
    batch_size = input_shape[0]

    doc_length = model_config.max_doc_length
    num_docs = model_config.num_docs

    # Documents arrive packed per example; flatten to one row per document.
    d_input_ids_unpacked = tf.reshape(d_input_ids, [-1, num_docs, doc_length])
    d_input_mask_unpacked = tf.reshape(d_input_mask, [-1, num_docs, doc_length])
    d_input_ids_flat = tf.reshape(d_input_ids_unpacked, [-1, doc_length])
    d_input_mask_flat = tf.reshape(d_input_mask_unpacked, [-1, doc_length])

    q_segment_ids = tf.zeros_like(q_input_ids, tf.int32)
    d_segment_ids = tf.zeros_like(d_input_ids_flat, tf.int32)

    label_ids = features["label_ids"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    # Each tower gets its own config copy with its own max_seq_length; pass
    # those copies, otherwise the adjusted lengths are never used.
    with tf.compat.v1.variable_scope(dual_model_prefix1):
        q_model_config = copy.deepcopy(model_config)
        q_model_config.max_seq_length = model_config.max_sent_length
        model_q = model_class(
            config=q_model_config,
            is_training=is_training,
            input_ids=q_input_ids,
            input_mask=q_input_mask,
            token_type_ids=q_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    with tf.compat.v1.variable_scope(dual_model_prefix2):
        d_model_config = copy.deepcopy(model_config)
        d_model_config.max_seq_length = model_config.max_doc_length
        model_d = model_class(
            config=d_model_config,
            is_training=is_training,
            input_ids=d_input_ids_flat,
            input_mask=d_input_mask_flat,
            token_type_ids=d_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    pooled_q = model_q.get_pooled_output()  # [batch, vector_size]
    pooled_d_flat = model_d.get_pooled_output()  # presumably [batch * num_docs, vector_size]

    pooled_d = tf.reshape(pooled_d_flat, [batch_size, num_docs, -1])
    pooled_q_t = tf.expand_dims(pooled_q, 1)
    pooled_d_t = tf.transpose(pooled_d, [0, 2, 1])
    all_logits = tf.matmul(pooled_q_t, pooled_d_t)  # [batch, 1, num_docs]

    if "hinge_all" in special_flags:
        apply_loss_modeling = hinge_all
        pred_threshold = 0.0
    elif "sigmoid_all" in special_flags:
        apply_loss_modeling = sigmoid_all
        # sigmoid_all returns sigmoid probabilities averaged over documents
        # (values in (0, 1)); thresholding at 0 would predict 1 for everything.
        pred_threshold = 0.5
    else:
        apply_loss_modeling = hinge_max
        pred_threshold = 0.0

    logits, loss = apply_loss_modeling(all_logits, label_ids)
    pred = tf.cast(logits > pred_threshold, tf.int32)

    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss, train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn, [pred, label_ids, is_real_example])
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "q_input_ids": q_input_ids,
            "d_input_ids": d_input_ids,
            "logits": logits
        }

        useful_inputs = ["data_id", "input_ids2", "data_ids"]
        for input_name in useful_inputs:
            if input_name in features:
                predictions[input_name] = features[input_name]
        output_spec = TPUEstimatorSpec(mode=mode, predictions=predictions,
                                       scaffold_fn=scaffold_fn)

    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """`model_fn` for TPUEstimator that encodes long inputs segment by segment.

    The first `config.total_sequence_length` tokens of every example are split
    by `split_and_append_sep2` (driven by `config.window_size`, CLS_ID and
    EOW_ID) into `num_seg` stacked segments. Every segment is encoded by
    `model_class`, and `pooling_modeling` combines the per-segment pooled and
    sequence outputs into logits. Those logits are trained with sparse softmax
    cross-entropy against `label_ids`.

    Args:
      features: dict of tensors. Reads "input_ids", "input_mask",
        "segment_ids" and "label_ids". If "focus_mask" is present, it is split
        the same way and written back into `features` in 2-D form. When the
        "feed_features" flag is set, the whole `features` dict is handed to
        `model_class`. "data_id", if present, is echoed in PREDICT outputs.
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value.
      params: unused.

    Returns:
      A `TPUEstimatorSpec` for `mode`.

    Raises:
      ValueError: if `mode` is not TRAIN, EVAL or PREDICT.
    """
    tf_logging.info("model_fn_pooling_long_things")
    log_features(features)
    input_ids = features["input_ids"]  # [batch_size, seq_length]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = tf.reshape(features["label_ids"], [-1])
    total_sequence_length = config.total_sequence_length

    def split_into_segments(mask):
        # Truncate to total_sequence_length and split into stacked segments.
        # `mask` occupies the input_mask slot, so the same call can also
        # segment an optional focus_mask.
        return split_and_append_sep2(input_ids[:, :total_sequence_length],
                                     mask[:, :total_sequence_length],
                                     segment_ids[:, :total_sequence_length],
                                     total_sequence_length, config.window_size,
                                     CLS_ID, EOW_ID)

    (stacked_input_ids, stacked_input_mask,
     stacked_segment_ids) = split_into_segments(input_mask)
    if "focus_mask" in features:
        _, stacked_focus_mask, _ = split_into_segments(features["focus_mask"])
        features["focus_mask"] = r3to2(stacked_focus_mask)

    batch_size, num_seg, seq_len = get_shape_list2(stacked_input_ids)
    # r3to2 presumably merges [batch, num_seg, seq_len] into
    # [batch * num_seg, seq_len]; the reshapes below undo that merge.
    input_ids_2d = r3to2(stacked_input_ids)
    input_mask_2d = r3to2(stacked_input_mask)
    segment_ids_2d = r3to2(stacked_segment_ids)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model_kwargs = dict(
        config=config,
        is_training=is_training,
        input_ids=input_ids_2d,
        input_mask=input_mask_2d,
        token_type_ids=segment_ids_2d,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    if "feed_features" in special_flags:
        model_kwargs["features"] = features
    model = model_class(**model_kwargs)

    sequence_output_2d = model.get_sequence_output()
    pooled_output = model.get_pooled_output()
    if is_training:
        pooled_output = dropout(pooled_output, 0.1)

    # Restore the segment axis so pooling can operate across segments.
    pooled_output_3d = tf.reshape(pooled_output, [batch_size, num_seg, -1])
    sequence_output_3d = tf.reshape(sequence_output_2d,
                                    [batch_size, num_seg, seq_len, -1])
    logits = pooling_modeling(config.option_name, train_config.num_classes,
                              pooled_output_3d, sequence_output_3d)
    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)
    loss = tf.reduce_mean(input_tensor=loss_arr)

    tvars = tf.compat.v1.trainable_variables()
    # Defaults to None so a valid spec is built when no checkpoint is set.
    scaffold_fn = None
    if train_config.init_checkpoint:
        _, init_fn = classification_model_fn.get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer_from_config(
            loss, train_config)
        return TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                scaffold_fn=scaffold_fn)
    if mode == tf.estimator.ModeKeys.EVAL:
        # The spec needs the ModeKeys value, not the encoder object.
        return TPUEstimatorSpec(mode=mode, loss=loss, eval_metrics=None,
                                scaffold_fn=scaffold_fn)
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"input_ids": input_ids, "logits": logits}
        if "data_id" in features:
            predictions['data_id'] = features['data_id']
        return TPUEstimatorSpec(mode=mode, loss=loss, predictions=predictions,
                                scaffold_fn=scaffold_fn)
    raise ValueError("Unsupported mode: {}".format(mode))
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Builds a synthetic training graph rather than a real model. It attaches
    `OomReportingHook`, so it appears to be a memory/OOM experiment (NOTE(review):
    confirm intent).

    Graph built:
      * one-hot embedding lookup of `input_ids` into a 40000 x 512 table;
      * `n_of_t` (10) copies of the embedded output, copy j offset by +j;
      * 20 iterations that transform the copies with Dense + dropout(0.5):
          - `modeling == "A"`: all copies are stacked and passed through a
            single shared layer ("MDense");
          - otherwise: copy j goes through its own layer ("MDense_{j}").
        After each iteration the new first entry is the sum of entries
        1..n_of_t-1, and entries 1..n_of_t-1 are kept unchanged;
      * loss = mean of the final first entry.

    Only a train spec is returned, whatever `mode` is. `input_mask`,
    `segment_ids` and `next_sentence_labels` are read but otherwise unused.

    Args:
      features: dict of tensors; must contain "input_ids", "input_mask",
        "segment_ids" and "next_sentence_labels".
      labels: unused.
      mode: passed through to the TPUEstimatorSpec.
      params: unused.

    Returns:
      A `TPUEstimatorSpec` with `train_op`, `loss` and an OOM reporting hook.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    next_sentence_labels = features["next_sentence_labels"]
    initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.02)
    vocab_size = 40000
    embedding_size = 512
    embedding_table = tf.compat.v1.get_variable(
        name="embedding",
        shape=[vocab_size, embedding_size],
        initializer=initializer)
    input_shape = get_shape_list(input_ids)

    # Embedding lookup done as one_hot @ table instead of a gather.
    flat_input_ids = tf.reshape(input_ids, [-1])
    one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
    output = tf.matmul(one_hot_input_ids, embedding_table)
    output = tf.reshape(output, input_shape + [embedding_size])

    # n_of_t copies of the embeddings, copy j shifted by the constant j.
    t_list = []
    n_of_t = 10
    for j in range(n_of_t):
        t_list.append(output + j)

    # Shared layer used by variant "A".
    dense = tf.keras.layers.Dense(embedding_size,
                                  kernel_initializer=initializer,
                                  name="MDense")
    # `modeling` comes from the enclosing scope (presumably a closure variable
    # of the factory that builds this model_fn).
    if modeling == "B":
        # One layer per copy for variant "B".
        dense_list = []
        for j in range(n_of_t):
            dense_list.append(
                tf.keras.layers.Dense(embedding_size,
                                      kernel_initializer=initializer,
                                      name="MDense_{}".format(j)))

    # The same layer objects are re-applied on every one of the 20 iterations.
    for i in range(20):
        if modeling == "A":
            t = tf.stack(t_list, 0)
            with tf.compat.v1.variable_scope("scope_A", reuse=i > 0):
                t = dense(t)
                t = tf.nn.dropout(t, rate=0.5)
                # New entry 0 = sum of entries 1..n_of_t-1; old entry 0 is dropped.
                t_0 = 0
                for j in range(1, n_of_t):
                    t_0 += t[j]
                new_t_list = [t_0]
                for j in range(1, n_of_t):
                    new_t_list.append(t[j])
                t_list = new_t_list
        else:
            # NOTE(review): any `modeling` value other than "A" lands here, but
            # dense_list is only defined when modeling == "B" -> NameError.
            with tf.compat.v1.variable_scope("scope_B", reuse=i > 0):
                temp_t = []
                for j in range(n_of_t):
                    t = dense_list[j](t_list[j])
                    t = tf.nn.dropout(t, rate=0.5)
                    temp_t.append(t)
                # Same recombination as variant "A".
                t_0 = 0
                for j in range(1, n_of_t):
                    t_0 += temp_t[j]
                new_t_list = [t_0]
                for j in range(1, n_of_t):
                    new_t_list.append(temp_t[j])
                t_list = new_t_list

    t = t_list[0]
    total_loss = tf.reduce_mean(t)
    # Debug dump of all trainable variables (rebinds `t`, which is unused below).
    for t in tf.compat.v1.trainable_variables():
        print(t)

    train_op = create_optimizer(total_loss, 1e-4, 1000, 1000, True)
    scaffold_fn = None
    output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        train_op=train_op,
        loss=total_loss,
        training_hooks=[OomReportingHook()],
        scaffold_fn=scaffold_fn)
    return output_spec
def attention_layer_w_ext(from_tensor,
                          to_tensor,
                          attention_mask=None,
                          num_attention_heads=1,
                          size_per_head=512,
                          ext_slice=None,  # [Num_tokens, n_items, hidden_dim]
                          query_act=None,
                          key_act=None,
                          value_act=None,
                          attention_probs_dropout_prob=0.0,
                          initializer_range=0.02,
                          do_return_2d_tensor=False,
                          batch_size=None,
                          from_seq_length=None,
                          to_seq_length=None):
    """Multi-headed attention with additive "ext" vectors from `ext_slice`.

    This is an implementation of multi-headed attention based on "Attention
    is all you Need". If `from_tensor` and `to_tensor` are the same, then
    this is self-attention. Each timestep in `from_tensor` attends to the
    corresponding sequence in `to_tensor`, and returns a fixed-width vector.

    This function first projects `from_tensor` into a "query" tensor and
    `to_tensor` into "key" and "value" tensors. These are (effectively) a list
    of tensors of length `num_attention_heads`, where each tensor is of shape
    [batch_size, seq_length, size_per_head].

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor and returned.

    In practice, the multi-headed attention are done with transposes and
    reshapes rather than actual separate tensors.

    Unlike plain attention, slices `ext_slice[:, idx, :]` are added to the
    2-D activations at these points:
      * EXT_QUERY_OUT: to the projected query;
      * EXT_KEY_IN / EXT_KEY_OUT: to the key input / projected key;
      * EXT_VALUE_IN / EXT_VALUE_OUT: to the value input / projected value.
    An EXT_QUERY_IN term is built but immediately overwritten, so it has no
    effect (see NOTE(review) in the body).

    Args:
      from_tensor: float Tensor of shape [batch_size, from_seq_length,
        from_width].
      to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
      attention_mask: (optional) int32 Tensor of shape [batch_size,
        from_seq_length, to_seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions in
        the mask that are 0, and will be unchanged for positions that are 1.
      num_attention_heads: int. Number of attention heads.
      size_per_head: int. Size of each attention head.
      ext_slice: float Tensor indexed as `ext_slice[:, idx, :]`. Each slice is
        added to the flattened [B*F, ...] / [B*T, ...] activations, so its
        first dimension must broadcast against both. The same slice rows are
        used for the query side and the key/value side, so this presumably
        only makes sense for self-attention (F == T). Despite the `None`
        default, it is required: it is indexed unconditionally.
      query_act: (optional) Activation function for the query transform.
      key_act: (optional) Activation function for the key transform.
      value_act: (optional) Activation function for the value transform.
      attention_probs_dropout_prob: (optional) float. Dropout probability of the
        attention probabilities.
      initializer_range: float. Range of the weight initializer.
      do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
        * from_seq_length, num_attention_heads * size_per_head]. If False, the
        output will be of shape [batch_size, from_seq_length, num_attention_heads
        * size_per_head].
      batch_size: (Optional) int. If the input is 2D, this might be the batch size
        of the 3D version of the `from_tensor` and `to_tensor`.
      from_seq_length: (Optional) If the input is 2D, this might be the seq length
        of the 3D version of the `from_tensor`.
      to_seq_length: (Optional) If the input is 2D, this might be the seq length
        of the 3D version of the `to_tensor`.

    Returns:
      float Tensor of shape [batch_size, from_seq_length,
        num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
        true, this will be of shape [batch_size * from_seq_length,
        num_attention_heads * size_per_head]).

    Raises:
      ValueError: Any of the arguments or tensor shapes are invalid.
    """

    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        # [B*S, N*W] -> [B, N, S, W]
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])

        output_tensor = tf.transpose(a=output_tensor, perm=[0, 2, 1, 3])
        return output_tensor

    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None or to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`

    from_tensor_2d = reshape_to_matrix(from_tensor)
    to_tensor_2d = reshape_to_matrix(to_tensor)

    def get_ext_slice(idx):
        # Row `idx` of every entry of ext_slice; fails if ext_slice is None.
        return ext_slice[:, idx, :]

    # Debug output of the flattened input shape.
    print("from_tensor_2d ", from_tensor_2d.shape)
    # NOTE(review): the ext-augmented query input is discarded by the very next
    # assignment, so EXT_QUERY_IN has no effect. Key and value use the opposite
    # order (ext applied last). Confirm whether this is an intentional toggle.
    query_in = from_tensor_2d + get_ext_slice(EXT_QUERY_IN)
    query_in = from_tensor_2d

    # `query_layer` = [B*F, N*H]
    query_layer = tf.keras.layers.Dense(
        num_attention_heads * size_per_head,
        activation=query_act,
        name="query",
        kernel_initializer=create_initializer(initializer_range))(query_in)

    query_layer = query_layer + get_ext_slice(EXT_QUERY_OUT)

    # The first assignment is dead; the ext-augmented input is what is used.
    key_in = to_tensor_2d
    key_in = to_tensor_2d + get_ext_slice(EXT_KEY_IN)
    # `key_layer` = [B*T, N*H]
    key_layer = tf.keras.layers.Dense(
        num_attention_heads * size_per_head,
        activation=key_act,
        name="key",
        kernel_initializer=create_initializer(initializer_range))(key_in)

    key_layer = key_layer + get_ext_slice(EXT_KEY_OUT)

    # The first assignment is dead; the ext-augmented input is what is used.
    value_in = to_tensor_2d
    value_in = to_tensor_2d + get_ext_slice(EXT_VALUE_IN)
    # `value_layer` = [B*T, N*H]
    value_layer = tf.keras.layers.Dense(
        num_attention_heads * size_per_head,
        activation=value_act,
        name="value",
        kernel_initializer=create_initializer(initializer_range))(value_in)

    value_layer = value_layer + get_ext_slice(EXT_VALUE_OUT)

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                       num_attention_heads, from_seq_length,
                                       size_per_head)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                     to_seq_length, size_per_head)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    # `attention_scores` = [B, N, F, T]
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # `attention_mask` = [B, 1, F, T]
        attention_mask = tf.expand_dims(attention_mask, axis=[1])

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_scores += adder

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

    # `value_layer` = [B, T, N, H]
    value_layer = tf.reshape(
        value_layer,
        [batch_size, to_seq_length, num_attention_heads, size_per_head])

    # `value_layer` = [B, N, T, H]
    value_layer = tf.transpose(a=value_layer, perm=[0, 2, 1, 3])

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(a=context_layer, perm=[0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
def transformer_model(input_tensor,
                      attention_mask=None,
                      input_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      mr_num_route=10,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      is_training=True,
                      do_return_all_layers=False):
    """Multi-headed, multi-layer Transformer encoder with routed "ext" layers.

    Based on the Transformer encoder from "Attention is All You Need"
    (https://arxiv.org/abs/1706.03762). Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

    Layers with index 0 and 1 are ordinary Transformer blocks. After each of
    them, `mr_num_route` routing logits are computed from the block's
    intermediate (feed-forward) activations, with dropout applied, and turned
    into a route index `key`: sampled with `tf.random.categorical` when
    `is_training`, otherwise the argmax. Layers with index > 1 ("MR" layers)
    use the `key` from layer 1 to gather per-row vectors from the
    `ext_tensor` / `ext_tensor_inter` variables. Those vectors are added to
    the attention inputs/outputs (via `attention_layer_w_ext`), to the
    attention output and its projection, to the intermediate activations, and
    to the layer output. Rows 0 and 1 of both ext variables are allocated but
    never read.

    Args:
      input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
      attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
        seq_length], with 1 for positions that can be attended to and 0 in
        positions that should not be.
      input_mask: not used by this function.
      hidden_size: int. Hidden size of the Transformer.
      num_hidden_layers: int. Number of layers (blocks) in the Transformer.
        Must be at least 1, otherwise `key` is never assigned.
      num_attention_heads: int. Number of attention heads in the Transformer.
      mr_num_route: int. Number of routes: width of the routing logits and
        second dimension of the `ext_tensor` / `ext_tensor_inter` variables.
      intermediate_size: int. The size of the "intermediate" (a.k.a., feed
        forward) layer.
      intermediate_act_fn: function. The non-linear activation function to apply
        to the output of the intermediate/feed-forward layer.
      hidden_dropout_prob: float. Dropout probability for the hidden layers
        (also applied to the routing logits).
      attention_probs_dropout_prob: float. Dropout probability of the attention
        probabilities.
      initializer_range: float. Range of the initializer (stddev of truncated
        normal).
      is_training: bool. True samples the route from the routing logits;
        False takes the argmax.
      do_return_all_layers: Whether to also return all layers or just the final
        layer.

    Returns:
      A tuple `(output, key)`. `output` is the final hidden layer, a float
      Tensor of shape [batch_size, seq_length, hidden_size], or a list with one
      such Tensor per layer if `do_return_all_layers`. `key` is the integer
      route index per row of the flattened activations, presumably of shape
      [batch_size * seq_length].

    Raises:
      ValueError: A Tensor shape or parameter is invalid.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]
    initializer = create_initializer(initializer_range)
    # Per-layer, per-route additive vectors: EXT_SIZE slots of width hidden_size
    # (indexed with the EXT_* constants) ...
    ext_tensor = tf.compat.v1.get_variable("ext_tensor",
                                           shape=[num_hidden_layers, mr_num_route, EXT_SIZE, hidden_size],
                                           initializer=initializer,
                                           )
    # ... and one vector of width intermediate_size for the feed-forward layer.
    ext_tensor_inter = tf.compat.v1.get_variable("ext_tensor_inter",
                                                 shape=[num_hidden_layers, mr_num_route, intermediate_size],
                                                 initializer=initializer,
                                                 )

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                         (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = reshape_to_matrix(input_tensor)

    def is_mr_layer(layer_idx):
        # Every layer after the first two uses the routed ext vectors.
        if layer_idx > 1:
            return True
        else:
            return False

    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        if not is_mr_layer(layer_idx):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output

                with tf.compat.v1.variable_scope("attention"):
                    attention_heads = []
                    with tf.compat.v1.variable_scope("self"):
                        attention_head = attention_layer(
                            from_tensor=layer_input,
                            to_tensor=layer_input,
                            attention_mask=attention_mask,
                            num_attention_heads=num_attention_heads,
                            size_per_head=attention_head_size,
                            attention_probs_dropout_prob=attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            do_return_2d_tensor=True,
                            batch_size=batch_size,
                            from_seq_length=seq_length,
                            to_seq_length=seq_length)
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just concatenate
                        # them to the self-attention head before the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.compat.v1.variable_scope("output"):
                        attention_output = dense(hidden_size, initializer)(attention_output)
                        attention_output = dropout(attention_output, hidden_dropout_prob)
                        attention_output = layer_norm(attention_output + layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.compat.v1.variable_scope("intermediate"):
                    intermediate_output = dense(intermediate_size, initializer,
                                                activation=intermediate_act_fn)(attention_output)

                # Down-project back to `hidden_size` then add the residual.
                with tf.compat.v1.variable_scope("output"):
                    layer_output = dense(hidden_size, initializer)(intermediate_output)
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

                # Routing: choose one of mr_num_route routes per row from the
                # intermediate activations. The key from the last non-MR layer
                # (layer 1) is the one the MR layers use.
                with tf.compat.v1.variable_scope("mr_key"):
                    key_output = tf.keras.layers.Dense(
                        mr_num_route,
                        kernel_initializer=create_initializer(initializer_range))(intermediate_output)
                    key_output = dropout(key_output, hidden_dropout_prob)

                    if is_training:
                        # Sampled route, [rows, 1] then flattened to [rows];
                        # rows is presumably batch_size * seq_length since
                        # intermediate_output is 2-D.
                        key = tf.random.categorical(key_output, 1)
                        key = tf.reshape(key, [-1])
                    else:
                        # Deterministic route at inference.
                        key = tf.math.argmax(input=key_output, axis=1)
        else:  # Case MR layer
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output
                # Route-selected ext vectors for this layer, one entry per row of
                # the flattened activations.
                ext_slice = tf.gather(ext_tensor[layer_idx], key)
                ext_interm_slice = tf.gather(ext_tensor_inter[layer_idx], key)
                # Debug output of the gathered shape.
                print("ext_slice (batch*seq, ", ext_slice.shape)
                with tf.compat.v1.variable_scope("attention"):
                    attention_heads = []
                    with tf.compat.v1.variable_scope("self"):
                        attention_head = attention_layer_w_ext(
                            from_tensor=layer_input,
                            to_tensor=layer_input,
                            attention_mask=attention_mask,
                            ext_slice=ext_slice,
                            num_attention_heads=num_attention_heads,
                            size_per_head=attention_head_size,
                            attention_probs_dropout_prob=attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            do_return_2d_tensor=True,
                            batch_size=batch_size,
                            from_seq_length=seq_length,
                            to_seq_length=seq_length)
                        attention_head = attention_head + ext_slice[:, EXT_ATT_OUT, :]
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just concatenate
                        # them to the self-attention head before the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.compat.v1.variable_scope("output"):
                        attention_output = dense(hidden_size, initializer)(attention_output)
                        attention_output = dropout(attention_output, hidden_dropout_prob)
                        attention_output = attention_output + ext_slice[:, EXT_ATT_PROJ, :]
                        attention_output = layer_norm(attention_output + layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.compat.v1.variable_scope("intermediate"):
                    intermediate_output = dense(intermediate_size, initializer,
                                                activation=intermediate_act_fn)(attention_output)
                    intermediate_output = ext_interm_slice + intermediate_output

                # Down-project back to `hidden_size` then add the residual.
                with tf.compat.v1.variable_scope("output"):
                    layer_output = dense(hidden_size, initializer)(intermediate_output)
                    layer_output = layer_output + ext_slice[:, EXT_LAYER_OUT, :]
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

    # NOTE(review): with num_hidden_layers == 0 `key` was never assigned and
    # both returns below raise.
    if do_return_all_layers:
        final_outputs = []
        for layer_output in all_layer_outputs:
            final_output = reshape_from_matrix(layer_output, input_shape)
            final_outputs.append(final_output)
        return final_outputs, key
    else:
        final_output = reshape_from_matrix(prev_output, input_shape)
        return final_output, key