import copy

import tensorflow as tf

# Project-internal helpers (get_shape_list, get_shape_list2,
# get_batch_and_seq_length, scatter_with_batch, gather_index2d,
# remove_special_mask, MASK_ID, tf_logging, optimization, the model classes,
# etc.) are assumed to be imported from the surrounding repository, as in the
# original source.


def delete_tokens(input_ids, n_trial, shift):
    # Build [n_trial, n_block_size] indices: trial i deletes the block of
    # tokens starting at `shift + i * n_block_size`.
    delete_location = []
    n_block_size = 1
    for i in range(n_trial):
        st = shift + i * n_block_size
        ed = shift + (i + 1) * n_block_size
        row = []
        for j in range(st, ed):
            row.append(j)
        delete_location.append(row)
    print(delete_location)
    batch_size, _ = get_shape_list2(input_ids)
    # [n_trial, 1]
    delete_location = tf.constant(delete_location, tf.int32)
    # [1, n_trial, 1]
    delete_location = tf.expand_dims(delete_location, 0)
    # [batch_size, n_trial, 1]
    delete_location = tf.tile(delete_location, [batch_size, 1, 1])
    # [n_trial, batch, 1]
    delete_location = tf.transpose(delete_location, [1, 0, 2])
    # [n_trial * batch, 1]
    delete_location = tf.reshape(delete_location, [batch_size * n_trial, -1])
    n_input_ids = tf.tile(input_ids, [n_trial, 1])
    masked_input_ids = scatter_with_batch(n_input_ids, delete_location, MASK_ID)
    return masked_input_ids
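
# Hedged sanity check for the index layout above (TF2 eager assumed; the real
# function also needs the project's get_shape_list2 / scatter_with_batch
# helpers, so only the pure-TF index construction is replayed here with
# hypothetical sizes).
def _demo_delete_locations(batch_size=2, n_trial=3, shift=1):
    loc = tf.constant([[shift + i] for i in range(n_trial)], tf.int32)  # [n_trial, 1]
    loc = tf.expand_dims(loc, 0)                       # [1, n_trial, 1]
    loc = tf.tile(loc, [batch_size, 1, 1])             # [batch, n_trial, 1]
    loc = tf.transpose(loc, [1, 0, 2])                 # [n_trial, batch, 1]
    loc = tf.reshape(loc, [batch_size * n_trial, -1])  # [n_trial * batch, 1]
    # Rows come out trial-major: [[1], [1], [2], [2], [3], [3]] -- the same
    # row order tf.tile(input_ids, [n_trial, 1]) produces.
    return loc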
def sigmoid_all(all_logits, label_ids):
    # all_logits: [batch, 1, num_seg]; label_ids: [batch, 1].
    print('all_logits', all_logits)
    batch_size, _, num_seg = get_shape_list(all_logits)
    label_ids_tile = tf.cast(
        tf.tile(tf.expand_dims(label_ids, 2), [1, 1, num_seg]), tf.float32)
    print('label_ids', label_ids)
    losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=all_logits,
                                                     labels=label_ids_tile)
    loss = tf.reduce_mean(losses)
    probs = tf.nn.sigmoid(all_logits)
    logits = tf.reduce_mean(probs, axis=2)
    return logits, loss
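
# Hedged sketch of sigmoid_all's math on hypothetical shapes, replayed without
# the project's get_shape_list helper so it runs standalone under TF2 eager.
def _demo_sigmoid_all():
    all_logits = tf.constant([[[2.0, -1.0, 0.5]],
                              [[-0.5, -2.0, 1.0]]])     # [batch=2, 1, num_seg=3]
    label_ids = tf.constant([[1], [0]])                 # [batch, 1]
    labels = tf.cast(tf.tile(tf.expand_dims(label_ids, 2), [1, 1, 3]), tf.float32)
    losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=all_logits,
                                                     labels=labels)
    loss = tf.reduce_mean(losses)                       # scalar over batch x segments
    scores = tf.reduce_mean(tf.nn.sigmoid(all_logits), axis=2)  # [batch, 1]
    return scores, loss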
def hinge_all(all_logits, label_ids):
    # Per-segment hinge loss against labels mapped to {-1, +1}.
    print('all_logits', all_logits)
    # logits = tf.reduce_max(all_logits, axis=2)
    y = tf.cast(label_ids, tf.float32) * 2 - 1
    print('label_ids', label_ids)
    print('y', y)
    y_expand = tf.expand_dims(y, 2)
    print('y_expand', y_expand)
    t = all_logits * y_expand
    losses = tf.maximum(1.0 - t, 0)
    loss = tf.reduce_mean(losses)
    logits = tf.reduce_mean(all_logits, axis=2)
    return logits, loss
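
# Hedged worked example of the hinge term: y in {-1, +1}, per-segment loss
# max(1 - y * score, 0). The scores below are hypothetical.
def _demo_hinge_all():
    all_logits = tf.constant([[[2.0, -0.5]],
                              [[0.3, -1.5]]])           # [batch=2, 1, num_seg=2]
    label_ids = tf.constant([[1], [0]])
    y = tf.cast(label_ids, tf.float32) * 2 - 1          # [[1.], [-1.]]
    t = all_logits * tf.expand_dims(y, 2)
    losses = tf.maximum(1.0 - t, 0)
    # example 0 (y=+1): [0.0, 1.5]; example 1 (y=-1): [1.3, 0.0]; mean = 0.7
    return tf.reduce_mean(losses)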
def one_by_one_masking(input_ids, input_masks, mask_token, n_trial):
    batch_size, seq_length = get_batch_and_seq_length(input_ids, 2)
    loc_dummy = tf.cast(tf.range(0, seq_length), tf.float32)
    loc_dummy = tf.tile(tf.expand_dims(loc_dummy, 0), [batch_size, 1])
    loc_dummy = remove_special_mask(input_ids, input_masks, loc_dummy)
    indices = tf.argsort(loc_dummy,
                         axis=-1,
                         direction='ASCENDING',
                         stable=False,
                         name=None)
    # [25, batch, 20]
    n_input_ids = tf.tile(input_ids, [n_trial, 1])
    # Transpose to [n_trial, batch] before flattening so the row order matches
    # the trial-major layout of tf.tile above (the original reshaped the
    # [batch, n_trial] slice directly, which pairs positions with the wrong
    # example rows whenever batch_size != n_trial).
    lm_locations = tf.reshape(tf.transpose(indices[:, :n_trial], [1, 0]), [-1, 1])
    masked_lm_positions = lm_locations  # [n_trial * batch, 1]
    masked_lm_ids = gather_index2d(n_input_ids, masked_lm_positions)
    masked_lm_weights = tf.ones_like(masked_lm_positions, dtype=tf.float32)
    masked_input_ids = scatter_with_batch(n_input_ids, masked_lm_positions,
                                          mask_token)
    return masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights
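
# Hedged sketch of the row layout that motivates the transpose above:
# tf.tile(input_ids, [n_trial, 1]) repeats the whole batch n_trial times, so
# a flattened position list must also be trial-major to line up row-for-row.
def _demo_tile_order():
    input_ids = tf.constant([[101, 7, 8, 102],
                             [101, 9, 5, 102]])         # batch=2, seq=4
    n_trial = 2
    n_input_ids = tf.tile(input_ids, [n_trial, 1])
    # rows: [ex0, ex1, ex0, ex1]; row r is example r % batch_size
    return n_input_ids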
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    # model_config, train_config, special_flags, model_class, and the
    # dual_model_prefix* scope names are captured from the enclosing builder
    # function (closure), as in the original source.
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    q_input_ids = features["q_input_ids"]
    q_input_mask = features["q_input_mask"]
    d_input_ids = features["d_input_ids"]
    d_input_mask = features["d_input_mask"]

    input_shape = get_shape_list(q_input_ids, expected_rank=2)
    batch_size = input_shape[0]

    doc_length = model_config.max_doc_length
    num_docs = model_config.num_docs

    d_input_ids_unpacked = tf.reshape(d_input_ids, [-1, num_docs, doc_length])
    d_input_mask_unpacked = tf.reshape(d_input_mask, [-1, num_docs, doc_length])
    d_input_ids_flat = tf.reshape(d_input_ids_unpacked, [-1, doc_length])
    d_input_mask_flat = tf.reshape(d_input_mask_unpacked, [-1, doc_length])

    q_segment_ids = tf.zeros_like(q_input_ids, tf.int32)
    d_segment_ids = tf.zeros_like(d_input_ids_flat, tf.int32)
    label_ids = features["label_ids"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    with tf.compat.v1.variable_scope(dual_model_prefix1):
        q_model_config = copy.deepcopy(model_config)
        q_model_config.max_seq_length = model_config.max_sent_length
        model_q = model_class(
            config=q_model_config,  # original passed `model_config`, leaving the copy unused
            is_training=is_training,
            input_ids=q_input_ids,
            input_mask=q_input_mask,
            token_type_ids=q_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
    with tf.compat.v1.variable_scope(dual_model_prefix2):
        d_model_config = copy.deepcopy(model_config)
        d_model_config.max_seq_length = model_config.max_doc_length
        model_d = model_class(
            config=d_model_config,  # original passed `model_config`, leaving the copy unused
            is_training=is_training,
            input_ids=d_input_ids_flat,
            input_mask=d_input_mask_flat,
            token_type_ids=d_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    pooled_q = model_q.get_pooled_output()       # [batch, vector_size]
    pooled_d_flat = model_d.get_pooled_output()  # [batch * num_docs, vector_size]
    pooled_d = tf.reshape(pooled_d_flat, [batch_size, num_docs, -1])
    pooled_q_t = tf.expand_dims(pooled_q, 1)
    pooled_d_t = tf.transpose(pooled_d, [0, 2, 1])
    all_logits = tf.matmul(pooled_q_t, pooled_d_t)  # [batch, 1, num_docs]
    if "hinge_all" in special_flags:
        apply_loss_modeling = hinge_all
    elif "sigmoid_all" in special_flags:
        apply_loss_modeling = sigmoid_all
    else:
        apply_loss_modeling = hinge_max
    logits, loss = apply_loss_modeling(all_logits, label_ids)
    pred = tf.cast(logits > 0, tf.int32)
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss, train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [pred, label_ids, is_real_example])
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "q_input_ids": q_input_ids,
            "d_input_ids": d_input_ids,
            "logits": logits,
        }
        useful_inputs = ["data_id", "input_ids2", "data_ids"]
        for input_name in useful_inputs:
            if input_name in features:
                predictions[input_name] = features[input_name]
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec
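
# Hedged sketch of the dual-encoder scoring step in isolation: each query
# vector is dotted against its num_docs document vectors with one batched
# matmul. All sizes below are hypothetical.
def _demo_dual_encoder_scoring():
    batch, num_docs, hidden = 2, 3, 4
    pooled_q = tf.random.normal([batch, hidden])
    pooled_d = tf.random.normal([batch, num_docs, hidden])
    all_logits = tf.matmul(tf.expand_dims(pooled_q, 1),        # [batch, 1, hidden]
                           tf.transpose(pooled_d, [0, 2, 1]))  # [batch, hidden, num_docs]
    return all_logits  # [batch, 1, num_docs]: one score per (query, doc) pair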
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    # model_config, train_config, and the helper functions are captured from
    # the enclosing builder function (closure), as in the original source.
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    next_sentence_labels = get_dummy_next_sentence_labels(input_ids)
    batch_size, seq_length = get_batch_and_seq_length(input_ids, 2)
    n_trial = seq_length - 20

    masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
        = one_by_one_masking(input_ids, input_mask, MASK_ID, n_trial)
    num_classes = train_config.num_classes
    n_repeat = num_classes * n_trial

    # [num_classes * n_trial * batch_size, seq_length]
    repeat_masked_input_ids = tf.tile(masked_input_ids, [num_classes, 1])
    repeat_input_mask = tf.tile(input_mask, [n_repeat, 1])
    repeat_segment_ids = tf.tile(segment_ids, [n_repeat, 1])
    masked_lm_positions = tf.tile(masked_lm_positions, [num_classes, 1])
    masked_lm_ids = tf.tile(masked_lm_ids, [num_classes, 1])
    masked_lm_weights = tf.tile(masked_lm_weights, [num_classes, 1])
    next_sentence_labels = tf.tile(next_sentence_labels, [n_repeat, 1])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    virtual_labels_ids = tf.tile(tf.expand_dims(tf.range(num_classes), 0),
                                 [1, batch_size * n_trial])
    virtual_labels_ids = tf.reshape(virtual_labels_ids, [-1, 1])
    print("repeat_masked_input_ids", repeat_masked_input_ids.shape)
    print("repeat_input_mask", repeat_input_mask.shape)
    print("virtual_labels_ids", virtual_labels_ids.shape)

    model = BertModelWithLabelInner(
        config=model_config,
        is_training=is_training,
        input_ids=repeat_masked_input_ids,
        input_mask=repeat_input_mask,
        token_type_ids=repeat_segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        label_ids=virtual_labels_ids,
    )

    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_log_probs) = get_masked_lm_output_fn(
         model_config, model.get_sequence_output(), model.get_embedding_table(),
         masked_lm_positions, masked_lm_ids, masked_lm_weights)
    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
         model_config, model.get_pooled_output(), next_sentence_labels)

    total_loss = masked_lm_loss

    # loss = -log(prob)
    # TODO compare log prob of each label
    per_case_loss = tf.reshape(masked_lm_example_loss,
                               [num_classes, -1, batch_size])
    per_label_loss = tf.reduce_sum(per_case_loss, axis=1)
    bias = tf.zeros([num_classes, 1])  # original hardcoded tf.zeros([3, 1])
    per_label_score = tf.transpose(-per_label_loss + bias, [1, 0])

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names, initialized_variable_names2, init_fn \
        = align_checkpoint_for_lm(tvars,
                                  train_config.checkpoint_type,
                                  train_config.init_checkpoint,
                                  train_config.second_init_checkpoint)
    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names,
                        initialized_variable_names2)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer_from_config(
            total_loss, train_config)
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_hooks=[OomReportingHook()],
            scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn_lm, [
            masked_lm_example_loss,
            masked_lm_log_probs,
            masked_lm_ids,
            masked_lm_weights,
        ])
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
    else:
        predictions = {"input_ids": input_ids, "logits": per_label_score}
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec
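
# Hedged sketch of the score aggregation above: per-trial MLM losses are
# summed per (class, example) and negated, so the virtual label under which
# the masked tokens are most predictable scores highest. Sizes and the random
# losses below are hypothetical.
def _demo_per_label_score():
    num_classes, n_trial, batch_size = 3, 4, 2
    example_loss = tf.random.uniform([num_classes * n_trial * batch_size])
    per_case_loss = tf.reshape(example_loss, [num_classes, n_trial, batch_size])
    per_label_loss = tf.reduce_sum(per_case_loss, axis=1)  # [num_classes, batch]
    return tf.transpose(-per_label_loss, [1, 0])           # [batch, num_classes]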