def replace_elements_by_indices(old, new, indices):
    """Overwrite entries of a rank-2 tensor at given per-row positions.

    `old` is treated as [batch_size, seq_length]: row b's position p maps to
    flat index b * seq_length + p.

    Args:
      old: tensor of shape [batch_size, seq_length] whose entries are replaced.
      new: flat tensor of replacement values, one per entry of `indices` in
        row-major order. Must have the same dtype as `old`.
      indices: [batch_size, num_positions] int32 positions within each row.
        Because `sparse_to_dense` is called with `validate_indices=True`, the
        resulting flat positions must be sorted and contain no duplicates,
        i.e. each row's positions must be strictly increasing.

    Returns:
      A tensor shaped like `old` equal to `old` everywhere except at
      `indices`, where it holds the corresponding values from `new`.
    """
    old_shape = modeling.get_shape_list(old)
    batch_size = old_shape[0]
    seq_length = old_shape[1]
    # Offset of each row's first element in the flattened tensor.
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(indices + flat_offsets, [-1])
    flat_old = tf.reshape(old, [-1])
    dense_shape = tf.shape(input=flat_old)
    # Build the keep-mask in old's dtype so the tf.multiply below type-checks
    # for any numeric dtype, not only int32.
    zeros = tf.zeros(tf.shape(input=flat_positions)[0], dtype=flat_old.dtype)
    # Mask is 0 at the replaced positions and 1 elsewhere.
    masked_lm_mask = tf.compat.v1.sparse_to_dense(
        flat_positions,
        dense_shape,
        zeros,
        default_value=1,
        validate_indices=True,
        name="masked_lm_mask")
    # Clear the positions being replaced...
    flat_old_temp = tf.multiply(flat_old, masked_lm_mask)
    # ...then scatter the new values into those (now zero) slots.
    new_temp = tf.compat.v1.sparse_to_dense(
        flat_positions,
        dense_shape,
        new,
        default_value=0,
        validate_indices=True,
        name=None)
    updated_old = tf.reshape(flat_old_temp + new_temp, old_shape)
    return updated_old
def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch.

    Args:
      sequence_tensor: rank-3 tensor [batch_size, seq_length, width].
      positions: [batch_size, num_positions] int32 positions within each row.

    Returns:
      A [batch_size * num_positions, width] tensor whose rows are
      sequence_tensor[b, positions[b, i]] in row-major (b, i) order.
    """
    batch_size, seq_length, width = modeling.get_shape_list(
        sequence_tensor, expected_rank=3)

    # Index of each batch row's first token once the batch is flattened.
    row_starts = tf.range(0, batch_size, dtype=tf.int32) * seq_length
    shifted = positions + tf.reshape(row_starts, [-1, 1])
    flat_positions = tf.reshape(shifted, [-1])

    token_matrix = tf.reshape(sequence_tensor,
                              [batch_size * seq_length, width])
    return tf.gather(token_matrix, flat_positions)
def gather_indexes_rank2(sequence_tensor, positions):
    """Gathers the scalars at the specific positions over a minibatch.

    Rank-2 counterpart of `gather_indexes`.

    Args:
      sequence_tensor: [batch_size, seq_length] tensor to gather from.
      positions: [batch_size, num_positions] int32 tensor (or nested list) of
        positions within each row.

    Returns:
      A [batch_size, num_positions] tensor with
      out[b, i] = sequence_tensor[b, positions[b, i]].
    """
    sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=2)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    positions = tf.convert_to_tensor(positions)
    # Take the per-row count from `positions` itself. A try/except around a
    # reshape to FLAGS.max_predictions_per_seq would depend on a global and
    # swallow unrelated errors.
    num_positions = modeling.get_shape_list(positions)[-1]
    # Offset of each row's first element in the flattened tensor.
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    return tf.reshape(output_tensor, [batch_size, num_positions])
def get_discriminator_output(electra_config, sequence_tensor,
                             whether_replaced, label_weights):
    """Computes the discriminator's weighted replaced-token-detection loss.

    A single-output linear head maps each token's hidden vector to a logit.
    The loss is sigmoid cross-entropy against `whether_replaced`, averaged
    with `label_weights` as per-token weights.

    Args:
      electra_config: config object; only `initializer_range` is read here.
      sequence_tensor: [batch_size, seq_length, width] hidden states.
      whether_replaced: tensor with batch_size * seq_length elements holding
        the per-token targets (cast to float32; expected to be 0/1).
      label_weights: tensor with batch_size * seq_length elements giving each
        token's weight in the average (e.g. an input mask).

    Returns:
      Scalar float32 loss: sum(w * xent) / (sum(w) + 1e-5).
    """
    sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]
    sequence_tensor = tf.reshape(sequence_tensor,
                                 [batch_size * seq_length, width])
    with tf.compat.v1.variable_scope("discriminator"):
        with tf.compat.v1.variable_scope("whether_replaced/predictions"):
            # Use the compat.v1 API consistently: plain `tf.get_variable` does
            # not exist under TF2.
            output_weights = tf.compat.v1.get_variable(
                "output_weights",
                shape=[1, width],
                initializer=modeling.create_initializer(
                    electra_config.initializer_range))
            output_bias = tf.compat.v1.get_variable(
                "output_bias",
                shape=[1],
                initializer=tf.compat.v1.zeros_initializer())
            # [B*S, 1] logits.
            logits = tf.matmul(sequence_tensor, output_weights,
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            whether_replaced = tf.cast(
                tf.reshape(whether_replaced, [batch_size * seq_length, 1]),
                tf.float32)
            sigmoid_cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=whether_replaced,
                logits=logits,
                name='sigmoid_cross_entropy',
            )

            label_weights = tf.cast(tf.reshape(label_weights, [-1]),
                                    tf.float32)
            sigmoid_cross_entropy = tf.reshape(sigmoid_cross_entropy, [-1])
            numerator = tf.reduce_sum(label_weights * sigmoid_cross_entropy)
            # Small epsilon avoids division by zero when all weights are 0.
            denominator = tf.reduce_sum(label_weights) + 1e-5
            loss = numerator / denominator
    return loss
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Builds a generator/discriminator pre-training graph:
      1. Sample `FLAGS.max_predictions_per_seq` positions per example and
         replace them with `MASK_ID`.
      2. Run the generator on the masked ids and sample replacement tokens
         with `temperature_sampling`.
      3. Write the sampled tokens back into the masked input and label each
         position with whether it now differs from the original token.
      4. Run the discriminator on the result and combine both losses as
         masked_lm_loss + FLAGS.disc_loss_weight * disc_loss.

    Only `tf.estimator.ModeKeys.TRAIN` is supported.

    NOTE(review): reads names this function does not define (`electra_config`,
    `use_one_hot_embeddings`, `init_checkpoint`, `use_tpu`, `learning_rate`,
    `num_warmup_steps`, `FLAGS`, `MASK_ID`, and helpers such as
    `get_masked_lm_output`); presumably they come from an enclosing builder
    or the module. Confirm where they are bound.

    Raises:
      ValueError: if `FLAGS.optimizer` is not 'lamb' or 'adam', or if `mode`
        is not TRAIN.
    """
    tf.compat.v1.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.compat.v1.logging.info(" name = %s, shape = %s" %
                                  (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    batch_size = modeling.get_shape_list(input_ids)[0]

    # Sample the masked positions in-graph so every training step draws fresh
    # ones. A tf.constant built from Python's `random` would freeze whatever
    # positions were drawn when the graph was built.
    # Candidates are 1 .. FLAGS.max_seq_length - 3. Ranking iid uniform noise
    # with top_k selects max_predictions_per_seq distinct positions per row.
    # They are sorted because replace_elements_by_indices passes them to
    # sparse_to_dense with validate_indices=True.
    num_candidates = FLAGS.max_seq_length - 3
    noise = tf.random.uniform([batch_size, num_candidates])
    _, sampled = tf.math.top_k(noise, k=FLAGS.max_predictions_per_seq)
    masked_lm_positions = tf.sort(sampled, axis=-1) + 1  # [B, P], int32

    # [B*P] mask-token ids; tf.fill also works when batch_size is dynamic.
    masks_list = tf.fill([batch_size * FLAGS.max_predictions_per_seq],
                         MASK_ID)

    # [B, P]: 1.0 where the sampled position is a real (non-padding) token.
    masked_lm_weights = tf.cast(
        gather_indexes_rank2(input_mask, masked_lm_positions), tf.float32)

    # [B, S]: inputs with MASK_ID at the sampled positions, padding zeroed.
    masked_input_ids = replace_elements_by_indices(input_ids, masks_list,
                                                   masked_lm_positions)
    masked_input_ids = tf.multiply(masked_input_ids, input_mask)

    # [B, P]: original tokens at the masked positions.
    masked_lm_ids = gather_indexes_rank2(input_ids, masked_lm_positions)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    generator = modeling.Generator(
        config=electra_config,
        is_training=is_training,
        input_ids=masked_input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    (masked_lm_loss, _, _, masked_logits) = get_masked_lm_output(
        electra_config, generator.get_sequence_output(),
        generator.get_embedding_table(), masked_lm_positions, masked_lm_ids,
        masked_lm_weights)

    masked_lm_predictions = temperature_sampling(masked_logits,
                                                 FLAGS.temperature)
    masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
    masked_lm_predictions = tf.reshape(masked_lm_predictions, [-1])

    # [B*P]: 1 where the generator's sample differs from the original token.
    diff_cast = tf.cast(tf.not_equal(masked_lm_predictions, masked_lm_ids),
                        tf.int32)

    # [B, S] discriminator labels: 1 at replaced positions, 0 elsewhere.
    zeros = tf.zeros(modeling.get_shape_list(input_ids), dtype=tf.int32)
    whether_replaced = replace_elements_by_indices(zeros, diff_cast,
                                                   masked_lm_positions)
    whether_replaced = tf.multiply(whether_replaced, input_mask)

    # [B, S] discriminator inputs: generator samples at the masked positions.
    input_ids_for_discriminator = replace_elements_by_indices(
        masked_input_ids, masked_lm_predictions, masked_lm_positions)
    input_ids_for_discriminator = tf.multiply(input_ids_for_discriminator,
                                              input_mask)

    discriminator = modeling.Discriminator(
        config=electra_config,
        is_training=is_training,
        input_ids=input_ids_for_discriminator,
        input_mask=input_mask,
        train_pooler=False,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    disc_loss = get_discriminator_output(
        electra_config, discriminator.get_sequence_output(),
        whether_replaced, input_mask)

    model_summary()

    tvars = tf.compat.v1.trainable_variables()
    scaffold_fn = None
    if init_checkpoint:
        (assignment_map, _) = modeling.get_assignment_map_from_checkpoint(
            tvars, init_checkpoint)
        if use_tpu:

            def tpu_scaffold():
                tf.compat.v1.train.init_from_checkpoint(
                    init_checkpoint, assignment_map)
                return tf.compat.v1.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.compat.v1.train.init_from_checkpoint(init_checkpoint,
                                                    assignment_map)

    total_loss = masked_lm_loss + FLAGS.disc_loss_weight * disc_loss

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Returning None here would only surface as an obscure Estimator error.
        raise ValueError("Only TRAIN mode is supported, got: %s" % (mode,))

    if FLAGS.optimizer == 'lamb':
        train_op = optimization.create_lamb_optimizer(
            loss=total_loss,
            init_lr=learning_rate,
            total_num_train_steps=FLAGS.total_num_train_steps,
            num_warmup_steps=num_warmup_steps,
            use_tpu=use_tpu,
            weight_decay=0.01,
        )
    elif FLAGS.optimizer == 'adam':
        train_op = optimization.create_adam_optimizer(
            loss=total_loss,
            init_lr=learning_rate,
            total_num_train_steps=FLAGS.total_num_train_steps,
            num_warmup_steps=num_warmup_steps,
            use_tpu=use_tpu,
            weight_decay=0.01,
        )
    else:
        # Raise instead of print + sys.exit(), which terminated the process
        # with a success status from inside the model function.
        raise ValueError("%s does not exist." % (FLAGS.optimizer,))

    output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
    )
    return output_spec