def __init__(self, config, is_training, use_one_hot_embeddings=True,
             features=None, scope=None):
    super(MES_sel, self).__init__()
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    unit_length = config.max_seq_length
    d_seq_length = config.max_d_seq_length
    num_window = int(d_seq_length / unit_length)
    batch_size, _ = get_shape_list2(input_ids)
    # [batch_size, num_window, unit_length]
    stacked_input_ids, stacked_input_mask, stacked_segment_ids = split_input(
        input_ids, input_mask, segment_ids, d_seq_length, unit_length)
    with tf.compat.v1.variable_scope(dual_model_prefix1):
        model = BertModel(
            config=config,
            is_training=is_training,
            input_ids=r3to2(stacked_input_ids),
            input_mask=r3to2(stacked_input_mask),
            token_type_ids=r3to2(stacked_segment_ids),
            use_one_hot_embeddings=use_one_hot_embeddings,
        )

        def r2to3(arr):
            return tf.reshape(arr, [batch_size, num_window, -1])

        # [batch_size * num_window, hidden_size]
        pooled = model.get_pooled_output()
        logits_2d = tf.keras.layers.Dense(2, name="cls_dense")(pooled)
        logits_3d = r2to3(logits_2d)  # [batch_size, num_window, 2]
    label_ids_repeat = tf.tile(label_ids, [1, num_window])
    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits_3d, labels=label_ids_repeat)
    layer1_loss = tf.reduce_mean(loss_arr)
    probs = tf.nn.softmax(logits_3d)[:, :, 1]  # [batch_size, num_window]

    # Segment selection: feed the highest-scoring window to the second model.
    def select_seg(stacked_input_ids, indices):
        # indices: [batch_size]
        return tf.gather(stacked_input_ids, indices, axis=1, batch_dims=1)

    max_seg = tf.argmax(probs, axis=1)
    input_ids = select_seg(stacked_input_ids, max_seg)
    input_mask = select_seg(stacked_input_mask, max_seg)
    segment_ids = select_seg(stacked_segment_ids, max_seg)
    with tf.compat.v1.variable_scope(dual_model_prefix2):
        model = BertModel(
            config=config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
        )
        logits = tf.keras.layers.Dense(2, name="cls_dense")(
            model.get_pooled_output())
    self.logits = logits
    label_ids = tf.reshape(label_ids, [-1])
    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)
    layer2_loss = tf.reduce_mean(loss_arr)
    alpha = self.get_alpha(config.decay_steps)
    loss = alpha * layer1_loss + layer2_loss
    self.loss = loss
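
# `split_input` and `r3to2` are repo utilities used above but not defined in
# this file. A minimal sketch of the assumed behavior, inferred from the shape
# comments ([batch, d_seq_length] -> [batch, num_window, unit_length], then
# back down to 2-D for BertModel). Names with a `_sketch` suffix are
# hypothetical, and the sketch ignores any per-window re-attachment of the
# query tokens that the real implementation may perform:
def split_input_sketch(input_ids, input_mask, segment_ids, d_seq_length, unit_length):
    num_window = d_seq_length // unit_length

    def split(t):  # [batch, d_seq_length] -> [batch, num_window, unit_length]
        return tf.reshape(t, [-1, num_window, unit_length])

    return split(input_ids), split(input_mask), split(segment_ids)


def r3to2_sketch(t):  # [batch, num_window, unit_length] -> [batch * num_window, unit_length]
    return tf.reshape(t, [-1, t.shape[-1]])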
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info("  name = %s, shape = %s" % (name, features[name].shape))
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    label_ids = tf.reshape(label_ids, [-1])
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model_1 = BertModel(
        config=model_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    pooled = model_1.get_pooled_output()
    if is_training:
        pooled = dropout(pooled, 0.1)
    logits = tf.keras.layers.Dense(train_config.num_classes,
                                   name="cls_dense")(pooled)
    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)
    loss = tf.reduce_mean(input_tensor=loss_arr)
    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)
    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    global_step = tf.compat.v1.train.get_or_create_global_step()
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        tvars = None
        train_op = create_optimizer_from_config(loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "logits": logits,
        }
        if "data_id" in features:
            predictions['data_id'] = features['data_id']
        output_spec = TPUEstimatorSpec(mode=mode,
                                       predictions=predictions,
                                       scaffold_fn=scaffold_fn)
    return output_spec
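
# A hedged usage sketch: how a model_fn like the one above is typically handed
# to TPUEstimator. The directory and batch sizes are illustrative placeholders,
# not values from this repo's train_config:
run_config = tf.compat.v1.estimator.tpu.RunConfig(
    model_dir="/tmp/mes_model",      # hypothetical output directory
    save_checkpoints_steps=1000,
)
estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=False,                   # TPUEstimator also runs on CPU/GPU
    train_batch_size=32,
    eval_batch_size=32,
    predict_batch_size=32,
)
# estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
# (train_input_fn and num_train_steps are hypothetical here.)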
def __init__(self, config, is_training, use_one_hot_embeddings=True,
             features=None, scope=None):
    super(MES_hinge, self).__init__()
    alpha = config.alpha
    input_ids1 = features["input_ids1"]
    input_mask1 = features["input_mask1"]
    segment_ids1 = features["segment_ids1"]
    input_ids2 = features["input_ids2"]
    input_mask2 = features["input_mask2"]
    segment_ids2 = features["segment_ids2"]
    input_ids = tf.concat([input_ids1, input_ids2], axis=0)
    input_mask = tf.concat([input_mask1, input_mask2], axis=0)
    segment_ids = tf.concat([segment_ids1, segment_ids2], axis=0)
    unit_length = config.max_seq_length
    d_seq_length = config.max_d_seq_length
    num_window = int(d_seq_length / unit_length)
    batch_size, _ = get_shape_list2(input_ids)
    # [batch_size, num_window, unit_length]
    stacked_input_ids, stacked_input_mask, stacked_segment_ids = split_input(
        input_ids, input_mask, segment_ids, d_seq_length, unit_length)
    # Ignore a window if it is not the first window and either
    #   1. all of its input_mask values are 0, or
    #   2. its content is too short: fewer than 10 document tokens
    #      (tokens other than the query tokens).
    # [batch_size, num_window]
    is_first_window = tf.concat([
        tf.ones([batch_size, 1], tf.bool),
        tf.zeros([batch_size, num_window - 1], tf.bool)
    ], axis=1)
    num_content_tokens = tf.reduce_sum(stacked_segment_ids, 2)
    has_enough_evidence = tf.less(10, num_content_tokens)
    is_valid_window = tf.logical_or(is_first_window, has_enough_evidence)
    is_valid_window_mask = tf.cast(is_valid_window, tf.float32)
    self.is_first_window = is_first_window
    self.num_content_tokens = num_content_tokens
    self.has_enough_evidence = has_enough_evidence
    self.is_valid_window = is_valid_window
    with tf.compat.v1.variable_scope(dual_model_prefix1):
        model = BertModel(
            config=config,
            is_training=is_training,
            input_ids=r3to2(stacked_input_ids),
            input_mask=r3to2(stacked_input_mask),
            token_type_ids=r3to2(stacked_segment_ids),
            use_one_hot_embeddings=use_one_hot_embeddings,
        )

        def r2to3(arr):
            return tf.reshape(arr, [batch_size, num_window, -1])

        # [batch_size * num_window, hidden_size]
        pooled = model.get_pooled_output()
        logits_2d = tf.keras.layers.Dense(1, name="cls_dense")(pooled)
        logits_3d = r2to3(logits_2d)  # [batch_size, num_window, 1]
    # Pairwise hinge over windows: row 0 holds the first (positive) examples
    # of each pair, row 1 the second (negative) examples.
    pair_logits_layer1 = tf.reshape(logits_3d, [2, -1, num_window])
    y_diff = pair_logits_layer1[0, :, :] - pair_logits_layer1[1, :, :]
    loss_arr = tf.maximum(1.0 - y_diff, 0)
    is_valid_window_pair = tf.reshape(is_valid_window, [2, -1, num_window])
    is_valid_window_and = tf.logical_and(is_valid_window_pair[0, :, :],
                                         is_valid_window_pair[1, :, :])
    is_valid_window_paired_mask = tf.cast(is_valid_window_and, tf.float32)
    loss_arr = loss_arr * is_valid_window_paired_mask
    layer1_loss = tf.reduce_mean(loss_arr)
    # Softmax across windows; a softmax over the single per-window logit
    # would be constant.
    probs = tf.nn.softmax(logits_3d[:, :, 0], axis=1)  # [batch_size, num_window]

    # Segment selection: feed the highest-scoring valid window to the
    # second model.
    def select_seg(stacked_input_ids, indices):
        # indices: [batch_size]
        return tf.gather(stacked_input_ids, indices, axis=1, batch_dims=1)

    valid_probs = probs * is_valid_window_mask
    max_seg = tf.argmax(valid_probs, axis=1)
    input_ids = select_seg(stacked_input_ids, max_seg)
    input_mask = select_seg(stacked_input_mask, max_seg)
    segment_ids = select_seg(stacked_segment_ids, max_seg)
    with tf.compat.v1.variable_scope(dual_model_prefix2):
        model = BertModel(
            config=config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
        )
        logits = tf.keras.layers.Dense(1, name="cls_dense")(
            model.get_pooled_output())
    pair_logits = tf.reshape(logits, [2, -1])
    y_diff2 = pair_logits[0, :] - pair_logits[1, :]
    loss_arr = tf.maximum(1.0 - y_diff2, 0)
    self.logits = logits
    layer2_loss = tf.reduce_mean(loss_arr)
    loss = alpha * layer1_loss + layer2_loss
    self.loss = loss
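
# A toy check of the pairwise hinge used in MES_hinge: logits are reshaped to
# [2, batch] with row 0 = first items of each pair and row 1 = second items,
# and the loss is max(1 - (s1 - s2), 0) averaged over pairs. The numbers here
# are made up for illustration:
logits_toy = tf.constant([[2.0, 0.3, -1.0],   # scores for the first items
                          [0.5, 0.8, -0.2]])  # scores for the second items
y_diff_toy = logits_toy[0, :] - logits_toy[1, :]   # [1.5, -0.5, -0.8]
hinge_toy = tf.maximum(1.0 - y_diff_toy, 0)        # [0.0,  1.5,  1.8]
# tf.reduce_mean(hinge_toy) == 1.1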
def __init__(self, config, is_training, use_one_hot_embeddings=True,
             features=None, scope=None):
    super(MES_pred, self).__init__()
    alpha = config.alpha
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    unit_length = config.max_seq_length
    d_seq_length = config.max_d_seq_length
    num_window = int(d_seq_length / unit_length)
    batch_size, _ = get_shape_list2(input_ids)
    # [batch_size, num_window, unit_length]
    stacked_input_ids, stacked_input_mask, stacked_segment_ids = split_input(
        input_ids, input_mask, segment_ids, d_seq_length, unit_length)
    # Ignore a window if it is not the first window and either
    #   1. all of its input_mask values are 0, or
    #   2. its content is too short: fewer than 10 document tokens
    #      (tokens other than the query tokens).
    # [batch_size, num_window]
    is_first_window = tf.concat([
        tf.ones([batch_size, 1], tf.bool),
        tf.zeros([batch_size, num_window - 1], tf.bool)
    ], axis=1)
    num_content_tokens = tf.reduce_sum(stacked_segment_ids, 2)
    has_enough_evidence = tf.less(10, num_content_tokens)
    is_valid_window = tf.logical_or(is_first_window, has_enough_evidence)
    is_valid_window_mask = tf.cast(is_valid_window, tf.float32)
    self.is_first_window = is_first_window
    self.num_content_tokens = num_content_tokens
    self.has_enough_evidence = has_enough_evidence
    self.is_valid_window = is_valid_window
    with tf.compat.v1.variable_scope(dual_model_prefix1):
        model = BertModel(
            config=config,
            is_training=is_training,
            input_ids=r3to2(stacked_input_ids),
            input_mask=r3to2(stacked_input_mask),
            token_type_ids=r3to2(stacked_segment_ids),
            use_one_hot_embeddings=use_one_hot_embeddings,
        )

        def r2to3(arr):
            return tf.reshape(arr, [batch_size, num_window, -1])

        # [batch_size * num_window, hidden_size]
        pooled = model.get_pooled_output()
        logits_2d = tf.keras.layers.Dense(1, name="cls_dense")(pooled)
        logits_3d = r2to3(logits_2d)  # [batch_size, num_window, 1]
    # Softmax across windows; a softmax over the single per-window logit
    # would be constant.
    probs = tf.nn.softmax(logits_3d[:, :, 0], axis=1)  # [batch_size, num_window]

    # Segment selection: feed the highest-scoring valid window to the
    # second model, as in MES_sel and MES_hinge.
    def select_seg(stacked_input_ids, indices):
        # indices: [batch_size]
        return tf.gather(stacked_input_ids, indices, axis=1, batch_dims=1)

    valid_probs = probs * is_valid_window_mask
    max_seg = tf.argmax(valid_probs, axis=1)
    input_ids = select_seg(stacked_input_ids, max_seg)
    input_mask = select_seg(stacked_input_mask, max_seg)
    segment_ids = select_seg(stacked_segment_ids, max_seg)
    with tf.compat.v1.variable_scope(dual_model_prefix2):
        model = BertModel(
            config=config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
        )
        logits = tf.keras.layers.Dense(1, name="cls_dense")(
            model.get_pooled_output())
    self.logits = logits
    # Prediction-only model: no training loss.
    self.loss = tf.constant(0.0)
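
# `get_shape_list2` is another shared utility; a minimal sketch, assuming it
# mirrors BERT's get_shape_list (static dimensions where known, dynamic
# tf.shape components otherwise). The `_sketch` name is hypothetical:
def get_shape_list2_sketch(tensor):
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dim if dim is not None else dynamic[i]
            for i, dim in enumerate(static)]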
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info("  name = %s, shape = %s" % (name, features[name].shape))
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    label_ids = tf.reshape(label_ids, [-1])
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model_1 = BertModel(
        config=model_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    pooled = model_1.get_pooled_output()
    if is_training:
        pooled = dropout(pooled, 0.1)
    logits = tf.keras.layers.Dense(train_config.num_classes,
                                   name="cls_dense")(pooled)
    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)
    loss = tf.reduce_mean(input_tensor=loss_arr)
    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)
    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Linear warmup followed by linear (power=1.0) polynomial decay to 0.
    init_lr = train_config.learning_rate
    num_warmup_steps = train_config.num_warmup_steps
    num_train_steps = train_config.num_train_steps
    learning_rate2_const = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
    learning_rate2_decayed = tf.compat.v1.train.polynomial_decay(
        learning_rate2_const,
        global_step,
        num_train_steps,
        end_learning_rate=0.0,
        power=1.0,
        cycle=False)
    if num_warmup_steps:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done
        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = ((1.0 - is_warmup) * learning_rate2_decayed
                         + is_warmup * warmup_learning_rate)
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        tvars = None
        train_op = optimization.create_optimizer_from_config(
            loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        def reform_scalar(t):
            return tf.reshape(t, [1])

        # Note: the warmup tensors below exist only when num_warmup_steps is set.
        predictions = {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "logits": logits,
            "learning_rate2_const": reform_scalar(learning_rate2_const),
            "warmup_percent_done": reform_scalar(warmup_percent_done),
            "warmup_learning_rate": reform_scalar(warmup_learning_rate),
            "learning_rate": reform_scalar(learning_rate),
            "learning_rate2_decayed": reform_scalar(learning_rate2_decayed),
        }
        if "data_id" in features:
            predictions['data_id'] = features['data_id']
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec
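
# A quick numeric check of the learning-rate schedule in the model_fn above:
# linear warmup to init_lr, then linear decay to 0. Values are illustrative
# (init_lr=1e-4, warmup=1000, total=10000), and `lr_at` is a hypothetical helper:
def lr_at(step, init_lr=1e-4, warmup=1000, total=10000):
    decayed = init_lr * (1.0 - step / total)   # polynomial_decay with power=1.0
    warmup_lr = init_lr * step / warmup        # init_lr * warmup_percent_done
    return warmup_lr if step < warmup else decayed

# lr_at(500)   -> 5e-05  (mid-warmup)
# lr_at(1000)  -> 9e-05  (decay takes over)
# lr_at(10000) -> 0.0    (end_learning_rate)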