Example #1
def residual_mlp_layer(x_flat,
                       intermediate_size,
                       initializer_range=0.02,
                       hidden_dropout_prob=0.1):
    """
    :param x: The attention output. It should be [batch_size*seq_length, dim]
    :param intermediate_size: the hidden projection. By default this is the input_dim * 4.

    in the original GPT we would return layer_norm(x_norm + h1) rather than layer_norm(x + h1)

    :return:
    """
    batch_size_seq_length, hidden_size = get_shape_list(x_flat,
                                                        expected_rank=2)
    x_norm = layer_norm(x_flat, name='mlp_ln0')

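    # Project up to the intermediate size with a GELU nonlinearity.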
    intermediate_output = tf.layers.dense(
        x_norm,
        intermediate_size,
        activation=gelu,
        kernel_initializer=create_initializer(initializer_range),
        name='intermediate',
    )

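    # Project back down to hidden_size; dropout precedes the residual add.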
    output_for_residual = tf.layers.dense(
        intermediate_output,
        hidden_size,
        name='output',
        kernel_initializer=create_initializer(initializer_range))
    output_for_residual = dropout(output_for_residual, hidden_dropout_prob)

    layer_output = layer_norm(x_flat + output_for_residual, name='mlp_ln1')
    return layer_output
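
The layer expects a rank-2 input, so callers flatten [batch_size, seq_length, dim] first and reshape back afterwards if needed. A minimal usage sketch, assuming TF 1.x graph mode, a hypothetical hidden_size of 768, and that the module's helpers (layer_norm, gelu, create_initializer, dropout, get_shape_list) are in scope:

import tensorflow as tf

hidden_size = 768  # hypothetical model dimension
x = tf.placeholder(tf.float32, [None, None, hidden_size])  # [batch, seq, dim]
x_flat = tf.reshape(x, [-1, hidden_size])                  # [batch*seq, dim]

with tf.variable_scope('layer_00/mlp'):
    h = residual_mlp_layer(x_flat,
                           intermediate_size=hidden_size * 4,
                           hidden_dropout_prob=0.1)
# h keeps the [batch*seq, hidden_size] shape and can feed the next block.
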
Example #2
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        label_ids = features["label_ids"]
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Create model with aux loss
        model = GroverModel(
            config=config,
            is_training=is_training,
            input_ids=input_ids,
            pad_token_id=config.pad_token_id,
            chop_off_last_token=False,
        )

        with tf.variable_scope('classification'):
            hidden_state = model.pooled_output(pool_token_id)
            if is_training:
                hidden_state = dropout(hidden_state, dropout_prob=0.1)
            logits = tf.layers.dense(hidden_state,
                                     num_labels,
                                     kernel_initializer=create_initializer(
                                         config.initializer_range),
                                     name='logits')
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids,
                                        depth=num_labels,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            class_loss = tf.reduce_mean(per_example_loss)

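        # Joint objective: auxiliary LM loss (scaled by the captured
        # lm_loss_coef) plus the classification loss.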
        total_loss = lm_loss_coef * model.lm_loss() + class_loss

        if is_training:
            train_op, train_metrics = optimization_adafactor.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                use_tpu)
            tvars = tf.trainable_variables()

            train_metrics['minibatch_cls_loss'] = class_loss
            train_metrics['minibatch_acc'] = tf.reduce_mean(
                tf.cast(
                    tf.equal(tf.argmax(logits, axis=-1, output_type=tf.int32),
                             label_ids), tf.float32))
        else:
            train_op = None
            train_metrics = {}
            tvars = tf.trainable_variables()

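        # Optionally warm-start from init_checkpoint. On TPU, variable
        # initialization has to happen inside a Scaffold, hence tpu_scaffold.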
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.debug("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.debug("  name = %s, shape = %s%s", var.name, var.shape,
                             init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    host_call=construct_scalar_host_call(
                        metric_dict=train_metrics,
                        model_dir=params['model_dir'],
                        prefix='training/'),
                    scaffold_fn=scaffold_fn)
            else:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    training_hooks=[
                        tf.train.LoggingTensorHook(
                            {'loss': tf.metrics.mean(total_loss)[1]},
                            every_n_iter=100)
                    ],
                    scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [
                per_example_loss, label_ids, logits, is_real_example
            ])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions={
                    'logits': logits,
                    'probs': tf.nn.softmax(logits, axis=-1)
                },
                scaffold_fn=scaffold_fn)
        return output_spec
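
The free variables in model_fn (config, num_labels, pool_token_id, lm_loss_coef, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu) are captured from an enclosing builder. A hedged sketch of that wiring; the builder name and the run-config/batch-size variables are assumptions for illustration, not the repository's exact API:

import tensorflow as tf

def classification_model_fn_builder(config, num_labels, pool_token_id,
                                    lm_loss_coef, init_checkpoint,
                                    learning_rate, num_train_steps,
                                    num_warmup_steps, use_tpu):
    """Binds the free variables and returns the model_fn closure above."""

    def model_fn(features, labels, mode, params):
        ...  # the body shown in Example #2

    return model_fn

# TPUEstimator calls the closure once per mode (TRAIN / EVAL / PREDICT).
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=use_tpu,
    model_fn=classification_model_fn_builder(
        config, num_labels, pool_token_id, lm_loss_coef, init_checkpoint,
        learning_rate, num_train_steps, num_warmup_steps, use_tpu),
    config=run_config,  # a tf.contrib.tpu.RunConfig built elsewhere
    params={'model_dir': run_config.model_dir},  # read by the host_call
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    predict_batch_size=predict_batch_size)
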
Example #3
def attention_layer(x_flat,
                    attention_mask,
                    batch_size,
                    seq_length,
                    size_per_head=512,
                    num_attention_heads=1,
                    *,
                    cache=None,
                    initializer_range=0.02,
                    hidden_dropout_prob=0.1,
                    attention_probs_dropout_prob=0.1,
                    do_cache=False):
    """

    :param x_flat: Tensor input, should be [batch_size*seq_length, dim]
    :param attention_mask: Attention mask to use of size [seq_length, seq_length+cached_length]
    :param size_per_head: dim = size_per_head * num_attention_heads
    :param num_attention_heads:  dim = size_per_head * num_attention_heads
    :param cache: Optionally some past (cached) things of size
                [batch, 2, heads, sequence, features], where 2 is [k, v]
    :param do_cache: True if we should return cache
    :return: A new tensor of shape [batch_size, seq_length, dim]
    as well as a new cache "cached_keys_and_values" that will be of size
                                   [batch_size, 2, num_attention_heads, seq_length, dim]
    """
    batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2)

    # Had to remove this because of generation script
    # if (batch_size_seq_length != batch_size * seq_length):
    #     raise ValueError("passed in a tensor of shape {} when batch_size={} and seq_length={}".format(
    #         (batch_size_seq_length, dim), batch_size, seq_length
    #     ))

    if dim != size_per_head * num_attention_heads:
        raise ValueError(
            "passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}"
            .format((batch_size_seq_length, dim), size_per_head,
                    num_attention_heads))

    # if do_cache and cache is not None:
    #     # Shape will be (batch_size, 2, num_attention_heads, past_seq_length, size_per_head)
    #     cache_shape = get_shape_list(cache, 5)
    #     desired_shape = (batch_size, 2, num_attention_heads, seq_length, size_per_head)
    #     if tuple(cache_shape) != desired_shape:
    #         raise ValueError(f"The shape of the cache is {cache_shape} but we want {desired_shape}")

    # [ batch_size, num_attention_heads, seq_length, size_per_head]
    query = _attention_projection_and_transpose(
        x_flat,
        batch_size=batch_size,
        seq_length=seq_length,
        num_attention_heads=num_attention_heads,
        size_per_head=size_per_head,
        name='query_layer',
        initializer_range=initializer_range)
    key = _attention_projection_and_transpose(
        x_flat,
        batch_size=batch_size,
        seq_length=seq_length,
        num_attention_heads=num_attention_heads,
        size_per_head=size_per_head,
        name='key_layer',
        initializer_range=initializer_range)

    value = _attention_projection_and_transpose(
        x_flat,
        batch_size=batch_size,
        seq_length=seq_length,
        num_attention_heads=num_attention_heads,
        size_per_head=size_per_head,
        name='value_layer',
        initializer_range=initializer_range)

    # Add to cache
    cached_keys_and_values = tf.stack([key, value],
                                      axis=1) if do_cache else None

    # Things that were relevant from the cache
    if cache is not None:
        pk, pv = tf.unstack(cache, axis=1)
        key = tf.concat([pk, key], axis=-2)
        value = tf.concat([pv, value], axis=-2)

    # Multiply [batch_size, num_attention_heads, seq_length, size_per_head] with
    #          [batch_size, num_attention_heads, size_per_head, seq_length+cached_length] ->
    #          [batch_size, num_attention_heads, seq_length, seq_length+cached_length]
    attention_scores = tf.matmul(query, key, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(size_per_head)))
    attention_scores = mask_attention_for_ltr(attention_scores, attention_mask)
    attention_probs = tf.nn.softmax(attention_scores)

    # The original Transformer applies dropout to the attention probabilities,
    # which effectively drops out entire tokens to attend to. That is
    # intentionally disabled here:
    # attention_probs = factoreddropout(attention_probs, attention_probs_dropout_prob)

    # Multiply [batch_size, num_attention_heads, seq_length, seq_length+cached_length] with
    #          [batch_size, num_attention_heads, seq_length+cached_length, size_per_head] ->
    #          [batch_size, num_attention_heads, seq_length, size_per_head]
    context_layer = tf.matmul(attention_probs, value)

    # `context_layer` = [batch_size, seq_length, num_attention_heads, size_per_head]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
    context_layer = tf.reshape(
        context_layer,
        [batch_size * seq_length, num_attention_heads * size_per_head])

    context_layer_projected = tf.layers.dense(
        context_layer,
        num_attention_heads * size_per_head,
        kernel_initializer=create_initializer(initializer_range),
        name='context_projection_layer')
    context_layer_projected = dropout(context_layer_projected,
                                      hidden_dropout_prob)

    return context_layer_projected, cached_keys_and_values
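
The cache exists so that generation can avoid recomputing keys and values for tokens already seen: prime it once over the prompt with do_cache=True, then feed it back for each single-token step. A hedged sketch under assumed shapes; prompt_flat, prompt_mask, token_flat, and step_mask are hypothetical placeholders:

import tensorflow as tf

batch_size, num_heads, size_per_head = 2, 12, 64  # hypothetical sizes
prompt_len = 16

with tf.variable_scope('layer_00/attn'):
    # First pass over the whole prompt; returns the primed cache.
    ctx, cache = attention_layer(
        prompt_flat,                 # [batch*prompt_len, dim]
        attention_mask=prompt_mask,  # [prompt_len, prompt_len]
        batch_size=batch_size,
        seq_length=prompt_len,
        size_per_head=size_per_head,
        num_attention_heads=num_heads,
        do_cache=True)

with tf.variable_scope('layer_00/attn', reuse=True):
    # One decoding step: a single new token attends to the cached keys/values.
    ctx_step, _ = attention_layer(
        token_flat,                  # [batch*1, dim]
        attention_mask=step_mask,    # [1, prompt_len + 1]
        batch_size=batch_size,
        seq_length=1,
        size_per_head=size_per_head,
        num_attention_heads=num_heads,
        cache=cache)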