示例#1
0
def replace_elements_by_indices(old, new, indices):
    """Return a copy of `old` with the entries at `indices` replaced by `new`.

    `old` is flattened to 1-D, the per-row `indices` are converted into
    offsets into that flat vector, and the selected entries are overwritten
    with the values of `new` (consumed in the same flattened order as
    `indices`). The result is reshaped back to the shape of `old`.

    Args:
        old: tensor to update. Only its first two dimensions are used to
            compute offsets (assumed to be [batch, seq_length] -- TODO confirm
            against callers). It is multiplied by an int32 mask, so it is
            presumably an int32 tensor.
        new: replacement values, one per flattened position.
        indices: int32 positions within each row of `old`; must broadcast
            against a [batch, 1] offset column.

    Returns:
        A tensor shaped like `old` containing the replaced values.
    """
    full_shape = modeling.get_shape_list(old)
    num_rows, row_width = full_shape[0], full_shape[1]

    # Row r of `old` begins at offset r * row_width in the flattened vector.
    row_starts = tf.range(0, num_rows, dtype=tf.int32) * row_width
    targets = tf.reshape(indices + tf.reshape(row_starts, [-1, 1]), [-1])

    cleared = tf.zeros(tf.shape(input=targets)[0], dtype=tf.int32)
    old_1d = tf.reshape(old, [-1])
    dense_shape = tf.shape(input=old_1d)

    # 0 at every target position, 1 everywhere else: multiplying by this mask
    # blanks out exactly the entries that are about to be replaced.
    keep_mask = tf.compat.v1.sparse_to_dense(
        targets,
        dense_shape,
        cleared,
        default_value=1,
        validate_indices=True,
        name="masked_lm_mask")
    kept = old_1d * keep_mask

    # Replacement values scattered to the target positions, 0 elsewhere.
    scattered = tf.compat.v1.sparse_to_dense(
        targets,
        dense_shape,
        new,
        default_value=0,
        validate_indices=True)

    return tf.reshape(kept + scattered, full_shape)
示例#2
0
def gather_indexes(sequence_tensor, positions):
    """Pick out the vectors located at `positions` in every batch row.

    Args:
        sequence_tensor: rank-3 tensor, read as [batch_size, seq_length,
            width].
        positions: int32 positions inside each sequence; must broadcast
            against a [batch_size, 1] offset column.

    Returns:
        A 2-D tensor with one `width`-sized row per entry of `positions`,
        in flattened (row-major) order.
    """
    dims = modeling.get_shape_list(sequence_tensor, expected_rank=3)
    batch_size, seq_length, width = dims[:3]

    # Turn per-row positions into indices into the flattened
    # [batch_size * seq_length, width] matrix.
    row_starts = tf.range(0, batch_size, dtype=tf.int32) * seq_length
    flat_index = tf.reshape(positions + tf.reshape(row_starts, [-1, 1]), [-1])

    flattened = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
    return tf.gather(flattened, flat_index)
示例#3
0
def gather_indexes_rank2(sequence_tensor, positions):
    """Gathers the scalar entries at the specific positions over a minibatch.

    Rank-2 counterpart of `gather_indexes`: the input has no trailing width
    dimension, so one scalar is gathered per position.

    Args:
        sequence_tensor: rank-2 tensor, read as [batch_size, seq_length].
        positions: int32 positions inside each row; must broadcast against a
            [batch_size, 1] offset column.

    Returns:
        The gathered values reshaped to
        [batch_size, FLAGS.max_predictions_per_seq], or to [batch_size, 1]
        when that reshape cannot be built.
    """
    sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=2)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]

    # Row r of the input starts at offset r * seq_length once flattened.
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    try:
        output_tensor = tf.reshape(output_tensor,
                                   [batch_size, FLAGS.max_predictions_per_seq])
    except Exception:  # pylint: disable=broad-except
        # Best-effort fallback to a single prediction per row. The handler is
        # kept broad on purpose: the failure may be a shape mismatch raised by
        # TensorFlow or a missing/unparsed flag. Unlike the previous bare
        # `except:`, this no longer swallows KeyboardInterrupt or SystemExit.
        output_tensor = tf.reshape(output_tensor, [batch_size, 1])
    return output_tensor
示例#4
0
def get_discriminator_output(electra_config, sequence_tensor, whether_replaced,
                             label_weights):
    """Compute the weighted replaced-token-detection loss.

    A single linear unit (weights of shape [1, width] plus a bias) is applied
    to every position of `sequence_tensor`, and the resulting logit is scored
    against `whether_replaced` with sigmoid cross-entropy. The per-position
    losses are averaged using `label_weights` as weights.

    Args:
        electra_config: configuration object; only `initializer_range` is
            read here.
        sequence_tensor: rank-3 tensor, read as [batch_size, seq_length,
            width].
        whether_replaced: per-position binary labels; reshaped to
            [batch_size * seq_length, 1] and cast to float32.
        label_weights: per-position weights; flattened and cast to float32.

    Returns:
        Scalar tensor: sum(weights * loss) / (sum(weights) + 1e-5).
    """
    sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    sequence_tensor = tf.reshape(sequence_tensor,
                                 [batch_size * seq_length, width])

    with tf.compat.v1.variable_scope("discriminator"):
        with tf.compat.v1.variable_scope("whether_replaced/predictions"):
            # Use the compat.v1 entry point, like the variable scopes above:
            # plain `tf.get_variable` is not available in the TF2 namespace.
            output_weights = tf.compat.v1.get_variable(
                "output_weights",
                shape=[1, width],
                initializer=modeling.create_initializer(
                    electra_config.initializer_range))
            output_bias = tf.compat.v1.get_variable(
                "output_bias",
                shape=[1],
                initializer=tf.zeros_initializer())

            # [batch_size * seq_length, 1] logits, one per token position.
            logits = tf.matmul(sequence_tensor,
                               output_weights,
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            whether_replaced = tf.cast(
                tf.reshape(whether_replaced, [batch_size * seq_length, 1]),
                tf.float32)
            sigmoid_cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=whether_replaced,
                logits=logits,
                name='sigmoid_cross_entropy',
            )

            label_weights = tf.cast(tf.reshape(label_weights, [-1]),
                                    tf.float32)
            sigmoid_cross_entropy = tf.reshape(sigmoid_cross_entropy, [-1])

            # Weighted mean; the epsilon guards against an all-zero weight
            # vector producing a division by zero.
            numerator = tf.reduce_sum(label_weights * sigmoid_cross_entropy)
            denominator = tf.reduce_sum(label_weights) + 1e-5
            loss = numerator / denominator

    return loss
示例#5
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Builds a generator/discriminator pre-training graph:

        1. Random positions of `input_ids` are replaced with `MASK_ID`.
        2. The generator predicts the masked tokens (masked-LM loss) and
           replacement tokens are sampled from its logits.
        3. The discriminator sees the input with the sampled tokens filled in
           and is trained to detect which positions differ from the original.

        The total loss is
        `masked_lm_loss + FLAGS.disc_loss_weight * disc_loss`.

        Names such as `electra_config`, `init_checkpoint`, `learning_rate`,
        `num_warmup_steps`, `use_tpu` and `use_one_hot_embeddings` are not
        defined here; presumably they come from the enclosing builder
        function (not visible in this chunk).

        Args:
            features: dict of tensors with keys "input_ids", "input_mask" and
                "segment_ids".
            labels: unused.
            mode: a `tf.estimator.ModeKeys` value.
            params: unused.

        Returns:
            A `TPUEstimatorSpec` when `mode` is TRAIN; `None` for every other
            mode (no eval/predict spec is built here).
        """

        tf.compat.v1.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.compat.v1.logging.info("  name = %s, shape = %s", name,
                                      features[name].shape)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        # Must be a static Python int: it is used with range() and list
        # multiplication below.
        batch_size = modeling.get_shape_list(input_ids)[0]

        # [B, max_predictions_per_seq]
        # Positions are sorted within each row before being handed to
        # replace_elements_by_indices.
        # NOTE(review): the positions are sampled in Python while the graph is
        # being built and then frozen into a tf.constant, so the same
        # positions are reused for every step run on this graph.
        masked_lm_positions = tf.constant([
            sorted(
                random.sample(range(1, FLAGS.max_seq_length - 2),
                              FLAGS.max_predictions_per_seq))
            for _ in range(batch_size)
        ])
        # [B * max_predictions_per_seq]
        masks_list = tf.constant([MASK_ID] *
                                 (FLAGS.max_predictions_per_seq * batch_size))
        # [B, max_predictions_per_seq]; weight is the input-mask value at each
        # sampled position, cast to float32.
        masked_lm_weights = tf.multiply(
            tf.ones(modeling.get_shape_list(masked_lm_positions)),
            tf.cast(gather_indexes_rank2(input_mask, masked_lm_positions),
                    tf.float32))

        # [B, S]
        masked_input_ids = replace_elements_by_indices(input_ids, masks_list,
                                                       masked_lm_positions)
        masked_input_ids = tf.multiply(masked_input_ids, input_mask)

        # [B, max_predictions_per_seq]: original token ids at the masked
        # positions, used as generator targets.
        masked_lm_ids = gather_indexes_rank2(input_ids, masked_lm_positions)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        generator = modeling.Generator(
            config=electra_config,
            is_training=is_training,
            input_ids=masked_input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
         masked_logits) = get_masked_lm_output(electra_config,
                                               generator.get_sequence_output(),
                                               generator.get_embedding_table(),
                                               masked_lm_positions,
                                               masked_lm_ids,
                                               masked_lm_weights)

        masked_lm_predictions = temperature_sampling(masked_logits,
                                                     FLAGS.temperature)

        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
        masked_lm_predictions = tf.reshape(masked_lm_predictions, [-1])
        diff = masked_lm_predictions - masked_lm_ids  # [B*max_predictions]

        zero = tf.constant(0, dtype=tf.int32)
        # 1 where the sampled token differs from the original token, else 0.
        diff_cast = tf.cast(tf.not_equal(diff, zero), tf.int32)

        # [B, S] discriminator labels: scatter the "differs" flags back to
        # their sequence positions and zero out positions where input_mask
        # is 0.
        zeros = tf.zeros(modeling.get_shape_list(input_ids), dtype=tf.int32)
        whether_replaced = replace_elements_by_indices(zeros, diff_cast,
                                                       masked_lm_positions)
        whether_replaced = tf.multiply(whether_replaced, input_mask)

        # [B, S] discriminator input: the sampled tokens written into the
        # masked positions.
        input_ids_for_discriminator = replace_elements_by_indices(
            masked_input_ids, masked_lm_predictions, masked_lm_positions)
        input_ids_for_discriminator = tf.multiply(input_ids_for_discriminator,
                                                  input_mask)

        discriminator = modeling.Discriminator(
            config=electra_config,
            is_training=is_training,
            input_ids=input_ids_for_discriminator,
            input_mask=input_mask,
            train_pooler=False,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        # input_mask doubles as the per-position loss weight.
        disc_loss = get_discriminator_output(
            electra_config, discriminator.get_sequence_output(),
            whether_replaced, input_mask)

        model_summary()

        tvars = tf.compat.v1.trainable_variables()

        scaffold_fn = None
        if init_checkpoint:
            # The second element (names of variables initialised from the
            # checkpoint) is not used here.
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.compat.v1.train.init_from_checkpoint(
                        init_checkpoint, assignment_map)
                    return tf.compat.v1.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.compat.v1.train.init_from_checkpoint(
                    init_checkpoint, assignment_map)

        output_spec = None
        total_loss = masked_lm_loss + FLAGS.disc_loss_weight * disc_loss
        if mode == tf.estimator.ModeKeys.TRAIN:
            if FLAGS.optimizer == 'lamb':
                train_op = optimization.create_lamb_optimizer(
                    loss=total_loss,
                    init_lr=learning_rate,
                    total_num_train_steps=FLAGS.total_num_train_steps,
                    num_warmup_steps=num_warmup_steps,
                    use_tpu=use_tpu,
                    weight_decay=0.01,
                )
            elif FLAGS.optimizer == 'adam':
                train_op = optimization.create_adam_optimizer(
                    loss=total_loss,
                    init_lr=learning_rate,
                    total_num_train_steps=FLAGS.total_num_train_steps,
                    num_warmup_steps=num_warmup_steps,
                    use_tpu=use_tpu,
                    weight_decay=0.01,
                )
            else:
                # Exit with a message so the process reports a failure
                # status; a bare sys.exit() would exit with status 0.
                sys.exit("%s does not exist." % FLAGS.optimizer)
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
            )
        return output_spec