Example #1
def symbols_to_logits_fn(input_symbols, i, state):
    # [batch_size, decoded_ids] to [batch_size, vocab_size]
    leng = shape_list(input_symbols)[1]
    pos_embed = encoder.vocab_size + tf.range(leng)
    pos_embed = tf.tile([pos_embed], [shape_list(input_symbols)[0], 1])
    inp = tf.pad(tf.stack([input_symbols, pos_embed], -1), [[0, 0], [0, 1], [0, 0]])
    target_feat_state = config.base_model.get_featurizer(
        inp,
        encoder=encoder,
        config=config,
        train=train,
        encoder_state=({**state["featurizer_state"], "embed_weights": embed_weights} if state else None)
    )
    output_state = language_model(
        X=inp,
        M=tf.sequence_mask(target_feat_state["eos_idx"] + 1, maxlen=leng + 1, dtype=tf.float32),  # +1 because we want to predict the clf token as an eos token
        embed_weights=embed_weights[:encoder.vocab_size, :],
        config=config,
        reuse=reuse,
        train=train,
        hidden=target_feat_state["sequence_features"]  # deal with state
    )
    if state is None:
        return output_state["logits"][:, i, :]
    else:
        return output_state["logits"][:, i, :], state
def dynamic_convolution(inp,
                        inp2=None,
                        n_heads=8,
                        kernel_size=4,
                        padding="causal"):
    batch, seq, n_channels = shape_list(inp)
    kernel_size_inc_ra = kernel_size
    if inp2 is not None:
        kernel_size_inc_ra += shape_list(inp2)[2]

    weights_linear = tf.reshape(
        linear(
            inp,
            n_heads * kernel_size_inc_ra,
            "kernel_machine",
        ),
        [batch, seq, kernel_size_inc_ra, n_heads],  # batch, time, kernel_size, heads
    )
    weights_linear = tf.nn.softmax(weights_linear, 2)

    if inp2 is not None:
        weights_ra = weights_linear[:, :, kernel_size:]
        weights_linear = weights_linear[:, :, :kernel_size]
        ra_dynamic = dynamic_conv_on_ra_out(inp2, weights_ra)
    else:
        ra_dynamic = 0.0

    if indico_ops.BUILT and tf.test.is_gpu_available():
        # The fused op has no CPU implementation, so check that a GPU is available before using it.
        dynamic_conv_fn = indico_ops.dynamic_convolution_op
    else:
        dynamic_conv_fn = dynamic_conv_cpu

    return dynamic_conv_fn(inp, weights_linear, padding=padding) + ra_dynamic
Example #3
def language_model(*, X, M, embed_weights, hidden, config, reuse=None):
    """
    A language model output and loss for the language modelling objective described in the original finetune paper.
    This language model uses weights that are tied to the input embedding.
    :param X: The raw token ids fed to the featurizer.
    :param M: A loss mask, with 1's where losses should be counted and 0's elsewhere.
    :param embed_weights: The word embedding matrix, normally the one returned by the featurizer.
    :param hidden: Output of the featurizer.
    :param config: A config object.
    :param reuse: A flag passed through to the tf.variable_scope context manager.
    :return: A dict containing:
        logits: The un-normalised log-probabilities over each word in the vocabulary.
        losses: The masked language modelling loss.
        perplexity: A per-example perplexity estimate over the masked positions.
    """
    X = merge_leading_dims(X, 3)
    M = merge_leading_dims(M, 2)
    hidden = merge_leading_dims(hidden, 3)
    batch, seq, _ = shape_list(X)
    with tf.variable_scope('model/language-model', reuse=reuse):
        # language model ignores last hidden state because we don't have a target
        sliced_hidden = hidden[:, :-1]
        lm_h = tf.reshape(
            sliced_hidden,
            [-1, config.n_embed
             ])  # [batch, seq_len, embed] --> [batch * seq_len, embed]
        lm_logits = tf.matmul(lm_h, embed_weights,
                              transpose_b=True)  # tied weights
        lm_logits = tf.reshape(
            lm_logits,
            [batch, seq - 1, tf.shape(embed_weights)[0]])

        lm_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=lm_logits, labels=X[:, 1:, 0])

        perplexity = tf.reduce_sum(tf.exp(lm_losses) * M[:, 1:],
                                   1) / tf.reduce_sum(M[:, 1:], 1)

        lm_losses = tf.reshape(
            lm_losses,
            [shape_list(X)[0], shape_list(X)[1] - 1])

        # tf.maximum op prevents divide by zero error when mask is all 0s
        lm_losses = tf.reduce_sum(lm_losses * M[:, 1:], 1) / tf.maximum(
            tf.reduce_sum(M[:, 1:], 1), 1)

        lm_logits_shape = shape_list(lm_logits)
        sliced_hidden_shape = shape_list(sliced_hidden)
        return {
            'logits':
            tf.reshape(lm_logits,
                       shape=sliced_hidden_shape[:-1] + [lm_logits_shape[-1]]),
            'losses':
            lm_losses,
            'perplexity':
            perplexity
        }
def dynamic_conv_on_ra_out(ra_out, weights):
    """
    ra_out: batch, length, ra_size, channels
    weights: Batch, Time, FilterWidth, Heads
    """
    batch, seq, ra_depth, n_channels = shape_list(ra_out)
    _, _, kernel_width, n_heads = shape_list(weights)
    assert n_channels % n_heads == 0
    assert ra_depth == kernel_width
    h = n_channels // n_heads
    unfolded = tf.reshape(ra_out, [batch, seq, ra_depth, n_heads, h])
    weights_expanded = tf.expand_dims(weights, 4)
    return tf.reshape(tf.reduce_sum(weights_expanded * unfolded, 2),
                      [batch, seq, n_channels])
def linear(inp, output_dim, layer_name):
    with tf.variable_scope(layer_name):
        nx = shape_list(inp)[-1]
        if output_dim is None:
            output_dim = nx
        W = tf.get_variable(name="W",
                            shape=[nx, output_dim],
                            initializer=tf.initializers.glorot_normal())
        if inp.dtype == tf.float16:
            W = tf.cast(W, tf.float16)

        return tf.reshape(
            tf.matmul(tf.reshape(inp, [-1, nx]),
                      tf.reshape(W, [nx, output_dim])),
            shape_list(inp)[:-1] + [output_dim])
Example #6
def mask_attn_weights(w):
    # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
    _, _, nd, ns = shape_list(w)
    b = attention_mask(nd, ns, dtype=w.dtype)
    b = tf.reshape(b, [1, 1, nd, ns])
    w = w * b - tf.cast(1e10, w.dtype) * (1 - b)
    return w
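The attention_mask helper is not shown in this listing; a minimal sketch of the causal band it is assumed to build (1s where destination position i may attend to source position j, following the usual GPT-style convention) could look like this:

import tensorflow as tf

def attention_mask_sketch(nd, ns, dtype=tf.float32):
    # Hypothetical stand-in for attention_mask(): lower-triangular band so that
    # destination position i only attends to source positions j <= i + ns - nd.
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    return tf.cast(i >= j - ns + nd, dtype)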
Example #7
def block(
    x,
    n_head,
    act_fn,
    adptr_size,
    resid_pdrop,
    attn_pdrop,
    scope,
    train=False,
    scale=False,
    explain=False,
):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        a = attn(
            x,
            "attn",
            nx,
            n_head,
            resid_pdrop,
            attn_pdrop,
            train=train,
            scale=scale,
            explain=explain,
        )
        if adptr_size is not None:
            with tf.variable_scope("attn_adapter"):
                a = adapter(a, adptr_size, nx, train)
        n = norm(x + a, "ln_1")
        m = mlp(n, "mlp", nx * 4, act_fn, resid_pdrop, train=train)
        if adptr_size is not None:
            with tf.variable_scope("dense_adapter"):
                m = adapter(m, adptr_size, nx, train)
        h = norm(n + m, "ln_2")
        return h
Example #8
def normal_1d_conv_block(X, kernel_width, layer_name, use_fp16, dilation=1, output_dim=None, causal=True):
    # layer_input shape = [batch, seq, embed_dim] or [batch, channels, seq, embed_dim]
    with tf.variable_scope(layer_name):
        # Pad by (kernel_width - 1) * dilation to stop the convolution viewing future positions
        # (all of the padding goes on the left when causal).
        left_pad = (kernel_width - 1) * dilation

        if causal:
            paddings = [[0, 0], [left_pad, 0], [0, 0]]
        else:
            paddings = [[0, 0], [left_pad // 2, left_pad - (left_pad // 2)], [0, 0]]

        if kernel_width > 1:
            padded_input = tf.pad(X, paddings, "CONSTANT")
        else:
            padded_input = X

        nx = shape_list(X)[-1]
        if output_dim is None:
            output_dim = nx
        W = tf.get_variable(name="W", shape=[kernel_width, nx, output_dim], initializer=tf.initializers.glorot_normal())
        b = tf.get_variable(name="B", shape=[output_dim], initializer=tf.initializers.constant(0.0))

        if use_fp16:
            W = tf.cast(W, tf.float16)
            b = tf.cast(b, tf.float16)

        conv = causal_conv(padded_input, W, dilation)
        conv = tf.nn.bias_add(conv, b)
    return conv
def dynamic_conv_cpu(inp, weights, padding="causal"):
    """
    inp : Batch Time Channels
    weights: Batch, Time, FilterWidth, Heads
    padding: causal or same
    """
    batch, seq, n_channels = shape_list(inp)
    _, _, kernel_width, n_heads = shape_list(weights)
    assert n_channels % n_heads == 0
    h = n_channels // n_heads

    unfolded = unfold(inp, kernel_width, padding=padding)
    unfolded = tf.reshape(unfolded, [batch, seq, kernel_width, n_heads, h])
    weights_expanded = tf.expand_dims(weights, 4)
    return tf.reshape(tf.reduce_sum(weights_expanded * unfolded, 2),
                      shape_list(inp))
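For reference, a minimal NumPy sketch of the weighted-window sum that dynamic_conv_cpu computes, assuming unfold produces, for each position t, the causal window ending at t, and that channel c belongs to head c // h:

import numpy as np

def dynamic_conv_reference(inp, weights):
    # inp: [batch, time, channels]; weights: [batch, time, kernel_width, heads]
    batch, seq, n_channels = inp.shape
    _, _, kernel_width, n_heads = weights.shape
    h = n_channels // n_heads
    out = np.zeros_like(inp)
    for t in range(seq):
        for k in range(kernel_width):
            src = t - (kernel_width - 1) + k  # causal window ending at t
            if src < 0:
                continue
            # each head's weight is shared across its h channels
            w = np.repeat(weights[:, t, k, :], h, axis=-1)  # [batch, channels]
            out[:, t, :] += w * inp[:, src, :]
    return out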
Example #10
def block(x,
          n_head,
          act_fn,
          adptr_size,
          resid_pdrop,
          attn_pdrop,
          scope,
          train=False,
          scale=False,
          explain=False):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        a = attn(x,
                 'attn',
                 nx,
                 n_head,
                 resid_pdrop,
                 attn_pdrop,
                 train=train,
                 scale=scale,
                 explain=explain)
        if adptr_size is not None:
            with tf.variable_scope('attn_adapter'):
                a = adapter(a, adptr_size, nx, train)
        n = norm(x + a, 'ln_1')
        m = mlp(n, 'mlp', nx * 4, act_fn, resid_pdrop, train=train)
        if adptr_size is not None:
            with tf.variable_scope('dense_adapter'):
                m = adapter(m, adptr_size, nx, train)
        h = norm(n + m, 'ln_2')
        return h
Example #11
def block(x, n_head, act_fn, resid_pdrop, attn_pdrop, scope, train=False, scale=False):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        a = attn(x, 'attn', nx, n_head, resid_pdrop, attn_pdrop, train=train, scale=scale)
        n = norm(x + a, 'ln_1')
        m = mlp(n, 'mlp', nx * 4, act_fn, resid_pdrop, train=train)
        h = norm(n + m, 'ln_2')
        return h
Example #12
def mlp(x, scope, n_state, act_fn, resid_pdrop, train=False):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        act = act_fns[act_fn]
        h = act(conv1d(x, "c_fc", n_state, 1, train=train))
        h2 = conv1d(h, "c_proj", nx, 1, train=train)
        h2 = dropout(h2, resid_pdrop, train)
        return h2
Example #13
def mask_pad(w, lengths):
    batch = shape_list(lengths)[0]
    maxlen = tf.cast(tf.reduce_max(lengths), tf.int32)
    seq_mask = tf.reshape(tf.sequence_mask(lengths, maxlen=maxlen),
                          [batch, 1, 1, maxlen])
    b = tf.cast(seq_mask, tf.float32)
    w = w * b + -1e9 * (1 - b)
    return w
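A hypothetical usage, assuming shape_list and mask_pad above are in scope: masking the attention logits of a batch of two sequences with lengths 3 and 1.

import tensorflow as tf

logits = tf.zeros([2, 1, 3, 3])                 # [batch, heads, query, key]
masked = mask_pad(logits, tf.constant([3, 1]))  # padded key positions pushed to -1e9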
Example #14
def language_model(*, X, M, embed_weights, hidden, config, reuse=None, train=False):
    """
    A language model output and loss for the language modelling objective described in the original finetune paper.
    This language model uses weights that are tied to the input embedding.
    :param X: The raw token ids fed to the featurizer.
    :param M: A loss mask, with 1's where losses should be counted and 0's elsewhere.
    :param embed_weights: The word embedding matrix, normally the one returned by the featurizer.
    :param hidden: Output of the featurizer.
    :param config: A config object.
    :param reuse: A flag passed through to the tf.variable_scope context manager.
    :return: A dict containing:
        logits: The un-normalised log-probabilities over each word in the vocabulary.
        losses: The masked language modelling loss.
        perplexity: A per-example perplexity estimate over the masked positions.
    """
    X = merge_leading_dims(X, 3)
    M = merge_leading_dims(M, 2)
    hidden = merge_leading_dims(hidden, 3)

    batch, seq, _ = shape_list(X)
    vocab_size, hidden_dim = shape_list(embed_weights)

    with tf.variable_scope('model/language-model', reuse=reuse):
        # the final position has no next-token target, so its logits are sliced off below
        lm_h = tf.reshape(hidden, [-1, config.n_embed])  # [batch, seq_len, embed] --> [batch * seq_len, embed]
        lm_logits = tf.matmul(lm_h, embed_weights, transpose_b=True)  # tied weights
        lm_logits = tf.cast(lm_logits, tf.float32)
        hidden_shape = tf.shape(hidden)
        logits = tf.reshape(lm_logits, shape=tf.concat([hidden_shape[:-1], [vocab_size]], axis=0))
        lm_logits_offset = tf.reshape(logits[:, :-1], [-1, vocab_size])
        
        lm_losses = tf.losses.sparse_softmax_cross_entropy(
            logits=lm_logits_offset,
            labels=tf.reshape(X[:, 1:, 0], [-1]),
            weights=tf.reshape(M[:, 1:], [-1])
        )

        perplexity = tf.reduce_sum(tf.exp(lm_losses) * M[:, 1:], 1) / tf.reduce_sum(M[:, 1:], 1)

        return {
            "logits": logits,
            "losses": lm_losses,
            "perplexity": perplexity,
        }
Example #15
def norm(x, scope, axis=[-1], e=1e-5):
    with tf.variable_scope(scope):
        n_state = shape_list(x)[-1]
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
        x = (x - u) * tf.rsqrt(s + e)
        x = x * g + b
        return x
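A hypothetical usage, assuming shape_list and norm above are in scope and a TF1-style graph/session: layer-normalising an activation tensor over its embedding axis.

import tensorflow as tf

x = tf.random_normal([2, 5, 768])  # [batch, seq, embed]
normed = norm(x, "ln_example")     # same shape, normalised over the last axis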
Example #16
def add_explain_tokens(X, max_length, pool_idx):
    flat_x = tf.reshape(X[:, :, :1], [-1, 1])
    flat_pos = tf.minimum(X[:, :, 1:] + 1,
                          max_length - 1)  # + 1 to offset for start token
    clf_tok = tf.gather(
        flat_x,
        tf.range(shape_list(X)[0], dtype=tf.int32) * max_length + pool_idx)
    clf_tok_x_seq = tf.tile(tf.expand_dims(clf_tok, 1), [1, max_length, 1])
    clf_tok_x_seq_w_pos = tf.concat((clf_tok_x_seq, flat_pos), -1)
    return tf.concat((X, clf_tok_x_seq_w_pos), 1)
Example #17
def attn_weights(q, k, v, attn_pdrop, train=False, scale=False, mask=True):
    w = tf.matmul(q, k)

    if scale:
        n_state = shape_list(v)[-1]
        w = w * tf.rsqrt(tf.cast(n_state, tf.float32))

    if mask:
        w = mask_attn_weights(w)
    w = tf.nn.softmax(w)
    return w
Example #18
def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    with tf.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.get_variable(
            'w', [1, nx, nf],
            initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
        c = tf.reshape(
            tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
            start + [nf])
        return c
Example #19
def conv1d(x, scope, nf, rf, w_init=tf.random_normal_initializer(stddev=0.02), b_init=tf.constant_initializer(0),
           pad='VALID', train=False):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        w = tf.get_variable("w", [rf, nx, nf], initializer=w_init)
        b = tf.get_variable("b", [nf], initializer=b_init)
        if rf == 1:  # faster 1x1 conv
            c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b, shape_list(x)[:-1] + [nf])
        else:  # was used to train LM
            c = tf.nn.conv1d(x, w, stride=1, padding=pad) + b
        return c
def dynamic_convolution_op(inp, weights, padding="causal"):
    """
    inp : Batch Time Channels
    weights: Batch, Time, FilterWidth, Heads
    padding: causal or same
    """
    batch, seq, n_channels = shape_list(inp)
    _, _, kernel_width, n_heads = shape_list(weights)
    assert n_channels % n_heads == 0
    
    if padding.lower() == "causal":
        padding_l = kernel_width - 1
    elif padding.lower() == "same":
        padding_l = kernel_width // 2
    else:
        raise ValueError("padding must be 'causal' or 'same', got {!r}".format(padding))
    inp_formatted = tf.transpose(inp, [0, 2, 1])
    weights_formatted = tf.transpose(weights, [0, 3, 2, 1])
    return tf.transpose(kernels_module.dynamic_convolution(inp_formatted, weights_formatted, padding_l), [0, 2, 1])
Example #21
def _merge_beam_dim(tensor):
    """Reshapes first two dimensions in to single dimension.

    Args:
      tensor: Tensor to reshape of shape [A, B, ...]

    Returns:
      Reshaped tensor of shape [A*B, ...]
    """
    shape = shape_list(tensor)
    shape[0] *= shape[1]  # batch -> batch * beam_size
    shape.pop(1)  # Remove beam dim
    return tf.reshape(tensor, shape)
Example #22
def masked_language_model(*, X, M, mlm_weights, mlm_positions, mlm_ids, embed_weights, hidden, config, reuse=None, train=False):
    X = merge_leading_dims(X, 3)
    M = merge_leading_dims(M, 2)
    hidden = merge_leading_dims(hidden, 3)
    batch, seq, _ = shape_list(X)
    with tf.variable_scope('model/masked-language-model'):
        gathered_hidden = gather_indexes(hidden, mlm_positions)
        final_proj = tf.layers.dense(
            gathered_hidden,
            units=config.n_embed,
            activation=act_fns[config.act_fn],
            kernel_initializer=tf.random_normal_initializer(stddev=config.weight_stddev),
            name='dense'
        )
        normed_proj = norm(final_proj, 'LayerNorm')
        n_vocab = shape_list(embed_weights)[0]
        output_bias = tf.get_variable(
            "output_bias",
            shape=[n_vocab],
            initializer=tf.zeros_initializer()
        )
        logits = tf.matmul(normed_proj, embed_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        
        mlm_ids = tf.reshape(mlm_ids, [-1])
        mlm_weights = tf.reshape(mlm_weights, [-1])

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(mlm_ids, depth=n_vocab, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
        numerator = tf.reduce_sum(mlm_weights * per_example_loss)
        denominator = tf.reduce_sum(mlm_weights) + 1e-5
        mlm_loss = numerator / denominator

        return {
            "logits": logits,
            "losses": mlm_loss,
        }
Example #23
def cumulative_state_net(X, name, use_fp16, pdrop, train, pool_kernel_size=2, nominal_pool_length=512, use_fused_kernel=True):
    conv_kernel = 4
    pool_kernel_size = pool_kernel_size or conv_kernel

    nx = shape_list(X)[-1]
    with tf.variable_scope(name):
        output = tf.nn.relu(normal_1d_conv_block(X, conv_kernel, "1-" + str(conv_kernel), use_fp16, output_dim=nx))
        output = tf.nn.relu(normal_1d_conv_block(output, conv_kernel, "2-" + str(conv_kernel), use_fp16, output_dim=nx))
        output = normal_1d_conv_block(output, conv_kernel, "3-" + str(conv_kernel), use_fp16, output_dim=nx)

    output = dropout(output, pdrop, train)
    aggregated = cascaded_pool(output, kernel_size=pool_kernel_size, pool_len=nominal_pool_length, use_fused_kernel=use_fused_kernel)

    return normal_1d_conv_block(aggregated, 1, "output_reproject", use_fp16, output_dim=nx)
Example #24
def cascaded_pool(value, kernel_size, dim=1, pool_len=None, use_fused_kernel=True):
    shape = shape_list(value)
    full_pool_len = pool_len or shape[dim]
    if use_fused_kernel:
        ra = recursive_agg
    else:
        ra = recursive_agg_tf

    aggregated = ra(value, kernel_size, full_pool_len)
    num_pooling_ops = shape_list(aggregated)[2]

    ws = normal_1d_conv_block(
        value, 1, "pool_w", value.dtype == tf.float16, output_dim=shape[-1]
    )
    wt = normal_1d_conv_block(
        value, 1, "pool_t", value.dtype == tf.float16, output_dim=num_pooling_ops
    )
    weights = tf.nn.softmax(wt)
    wt = tf.expand_dims(weights, -1)

    weighted_over_time = tf.reduce_mean(aggregated * wt, 2) * tf.nn.sigmoid(ws)

    return weighted_over_time
Example #25
def norm(x, scope, axis=-1, e=None, fp16=False):
    with tf.variable_scope(scope):
        e = e or 1e-5
        n_state = shape_list(x)[-1]
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        if fp16:
            g = tf.cast(g, tf.float16)
            b = tf.cast(b, tf.float16)
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
        x = (x - u) * tf.rsqrt(s + e)
        x = x * g + b
        return x
Example #26
def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch."""
    sequence_shape = shape_list(sequence_tensor)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]
    )
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    return output_tensor
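A hypothetical usage, assuming shape_list and gather_indexes above are in scope: pulling out the hidden vectors at two masked positions per example.

import tensorflow as tf

hidden = tf.reshape(tf.range(24, dtype=tf.float32), [2, 4, 3])  # [batch, seq, width]
positions = tf.constant([[1, 3], [0, 2]])                       # masked positions per example
gathered = gather_indexes(hidden, positions)                    # shape [4, 3]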
Example #27
def _unmerge_beam_dim(tensor, batch_size, beam_size):
    """Reshapes first dimension back to [batch_size, beam_size].

    Args:
      tensor: Tensor to reshape of shape [batch_size*beam_size, ...]
      batch_size: Tensor, original batch size.
      beam_size: int, original beam size.

    Returns:
      Reshaped tensor of shape [batch_size, beam_size, ...]
    """
    shape = shape_list(tensor)
    new_shape = [batch_size] + [beam_size] + shape[1:]
    return tf.reshape(tensor, new_shape)
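A hypothetical round trip, assuming shape_list and the two beam helpers above are in scope; beam search typically flattens [batch, beam, ...] to run the model, then restores the beam axis afterwards.

import tensorflow as tf

ids = tf.zeros([2, 3, 5], dtype=tf.int32)                      # [batch=2, beam=3, length=5]
flat = _merge_beam_dim(ids)                                    # [6, 5]
restored = _unmerge_beam_dim(flat, batch_size=2, beam_size=3)  # [2, 3, 5]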
Example #28
def attn_weights(q, k, v, scale=False, mask=True, explain=False):
    w = tf.matmul(q, k)

    if scale:
        n_state = shape_list(v)[-1]
        w = w * tf.rsqrt(tf.cast(n_state, tf.float32))

    if mask:
        if explain:
            w = explain_mask_attn_weights(w)
        else:
            w = mask_attn_weights(w)
    w = tf.nn.softmax(w)
    return w
Example #29
def enc_dec_mix(enc, dec, enc_mask, dec_mask, n_head=16):
    # enc = batch, seq, feats
    # dec = batch, seq, feats
    with tf.variable_scope("enc_dec_attn"):
        batch, dec_seq, feats = shape_list(dec)
        enc_seq = shape_list(enc)[1]
        
        enc_proj = normal_1d_conv_block(enc, 1, "enc_proj", use_fp16=enc.dtype == tf.float16, output_dim=feats * 2)
        dec_proj = normal_1d_conv_block(dec, 1, "dec_proj", use_fp16=enc.dtype == tf.float16, output_dim=feats)
        k, v = tf.split(enc_proj, 2, 2)
        q = dec_proj
        q = split_heads(q, n_head)
        k = split_heads(k, n_head, k=True)
        v = split_heads(v, n_head)
        w = tf.matmul(q, k)
        w = w * tf.rsqrt(cast_maybe(dec_seq, tf.float32))
        
        enc_mask = tf.reshape(tf.sequence_mask(enc_mask, maxlen=enc_seq, dtype=enc.dtype), [batch, 1, 1, enc_seq])
        dec_mask = tf.reshape(tf.sequence_mask(dec_mask, maxlen=dec_seq, dtype=enc.dtype), [batch, 1, dec_seq, 1])
        m = enc_mask * dec_mask
        w = w * m + -1e9 * (1 - m)
        w = tf.nn.softmax(w)
        
    return merge_heads(tf.matmul(w, v))
Example #30
def explain_mask_attn_weights(w):
    # w is [batch, heads, n, n], where n = 2 * seq: the original sequence followed by its per-position explain (clf) tokens
    batch, _, _, n = shape_list(w)
    seq = n // 2
    main_mask = tf.matrix_band_part(tf.ones([seq, seq]), -1, 0)
    top = tf.expand_dims(tf.concat((main_mask, tf.zeros([seq, seq])), 1),
                         0)  # 1, seq, 2 * seq
    clf_to_clf_mask = tf.eye(seq)
    bottom = tf.expand_dims(tf.concat((main_mask, clf_to_clf_mask), 1),
                            0)  # 1, seq, 2 * seq
    m = tf.concat((top, bottom), 1)
    b = tf.reshape(m, [1, 1, n, n])
    w = w * b + -1e9 * (1 - b)
    return w