Example 1
def double_linear_logits(args,
                         size,
                         bias,
                         bias_start=0.0,
                         scope=None,
                         mask=None,
                         wd=0.0,
                         input_keep_prob=1.0,
                         is_train=None):
    with tf.variable_scope(scope or "Double_Linear_Logits"):
        first = tf.tanh(
            linear(args,
                   size,
                   bias,
                   bias_start=bias_start,
                   scope='first',
                   wd=wd,
                   input_keep_prob=input_keep_prob,
                   is_train=is_train))
        second = linear(first,
                        1,
                        bias,
                        bias_start=bias_start,
                        squeeze=True,
                        scope='second',
                        wd=wd,
                        input_keep_prob=input_keep_prob,
                        is_train=is_train)
        if mask is not None:
            second = exp_mask(second, mask)
        return second
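
The exp_mask call here presumably keeps padded positions out of a later softmax by pushing their logits toward a very large negative value. A minimal NumPy sketch of that idea (a hypothetical stand-in, not the project's exp_mask):

import numpy as np

VERY_NEGATIVE = -1e30

def exp_mask_sketch(logits, mask):
    """Push logits at masked-out positions (mask == 0) toward -inf so a
    subsequent softmax assigns them ~zero probability."""
    return logits + (1.0 - mask.astype(logits.dtype)) * VERY_NEGATIVE

logits = np.array([1.0, 2.0, 3.0])
mask = np.array([1, 1, 0])
masked = exp_mask_sketch(logits, mask)
probs = np.exp(masked - masked.max())
probs /= probs.sum()
print(probs)  # the masked position gets ~0 probability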
Example 2
def highway_layer(arg, bias, bias_start=0.0, scope=None, wd=0.0, input_keep_prob=1.0, is_train=None):
    with tf.variable_scope(scope or "highway_layer"):
        d = arg.get_shape()[-1]  # embedding dim
        trans = linear([arg], d, bias, bias_start=bias_start, scope='trans', wd=wd, input_keep_prob=input_keep_prob,
                       is_train=is_train)
        trans = tf.nn.relu(trans)
        gate = linear([arg], d, bias, bias_start=bias_start, scope='gate', wd=wd, input_keep_prob=input_keep_prob,
                      is_train=is_train)
        gate = tf.nn.sigmoid(gate)
        out = gate * trans + (1 - gate) * arg
        return out
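
The highway gate above blends a ReLU transform of the input with the input itself: out = gate * trans + (1 - gate) * arg. A self-contained NumPy sketch of that arithmetic (hypothetical weights, not the repo's linear helper):

import numpy as np

def highway_numpy(x, w_t, b_t, w_g, b_g):
    """Minimal highway layer: a ReLU transform gated against the identity path."""
    trans = np.maximum(x @ w_t + b_t, 0.0)           # transform branch (ReLU)
    gate = 1.0 / (1.0 + np.exp(-(x @ w_g + b_g)))    # sigmoid gate in (0, 1)
    return gate * trans + (1.0 - gate) * x           # blend transform and carry

rng = np.random.default_rng(0)
d = 4
x = rng.standard_normal((2, d))
out = highway_numpy(x, rng.standard_normal((d, d)), np.zeros(d),
                    rng.standard_normal((d, d)), np.zeros(d))
print(out.shape)  # (2, 4)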
Example 3
    def __init__(self, cfg, num_modules):
        super().__init__()
        self.cfg = cfg
        self.num_modules = num_modules
        control_dim = cfg.MODEL.KB_DIM
        if cfg.MODEL.CTRL.USE_WORD_EMBED:
            control_dim = cfg.MODEL.EMBED_DIM
        dim = cfg.MODEL.LSTM_DIM

        self.shared_control_proj = linear(dim, dim)
        self.position_aware = nn.ModuleList()
        for i in range(cfg.MODEL.T_CTRL):
            self.position_aware.append(linear(dim, dim))

        self.control_question = linear(dim + control_dim, dim)
        self.attn = linear(dim, 1)

        if self.cfg.MODEL.CTRL.LINEAR_MODULE_WEIGHTS:
            self.module_fc = nn.Linear(dim, num_modules, bias=False)
        else:
            self.module_fc = nn.Sequential(
                nn.Linear(dim, cfg.MODEL.LSTM_DIM), nn.ELU(),
                nn.Linear(cfg.MODEL.LSTM_DIM, num_modules))
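
The per-step projections are kept in an nn.ModuleList rather than a plain Python list so that their parameters are registered with (and trained as part of) the parent module. A minimal, self-contained sketch of that pattern (hypothetical dimensions, plain nn.Linear instead of the repo's linear helper):

import torch
from torch import nn

class PositionAwareProj(nn.Module):
    def __init__(self, dim=8, num_steps=4):
        super().__init__()
        # One projection per reasoning step; ModuleList registers every layer's parameters.
        self.position_aware = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(num_steps)])

    def forward(self, x, step):
        # Select the projection that belongs to the current step.
        return self.position_aware[step](x)

model = PositionAwareProj()
print(sum(p.numel() for p in model.parameters()))  # parameters of all 4 step-specific layers
print(model(torch.randn(2, 8), step=1).shape)      # torch.Size([2, 8])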
Example 4
def linear_logits(args,
                  bias,
                  bias_start=0.0,
                  scope=None,
                  mask=None,
                  wd=0.0,
                  input_keep_prob=1.0,
                  is_train=None):
    with tf.variable_scope(scope or "Linear_Logits"):
        logits = linear(args,
                        1,
                        bias,
                        bias_start=bias_start,
                        squeeze=True,
                        scope='first',
                        wd=wd,
                        input_keep_prob=input_keep_prob,
                        is_train=is_train)
        if mask is not None:
            logits = exp_mask(logits, mask)
        return logits
Example 5
def get_logits(args,
               size,
               bias,
               bias_start=0.0,
               scope=None,
               mask=None,
               wd=0.0,
               input_keep_prob=1.0,
               is_train=None,
               func=None):
    if func is None:
        func = "linear"
    if func == 'sum':
        return sum_logits(args, mask=mask, name=scope)
    elif func == 'linear':
        return linear_logits(args,
                             bias,
                             bias_start=bias_start,
                             scope=scope,
                             mask=mask,
                             wd=wd,
                             input_keep_prob=input_keep_prob,
                             is_train=is_train)
    elif func == 'double':
        return double_linear_logits(args,
                                    size,
                                    bias,
                                    bias_start=bias_start,
                                    scope=scope,
                                    mask=mask,
                                    wd=wd,
                                    input_keep_prob=input_keep_prob,
                                    is_train=is_train)
    elif func == 'dot':
        assert len(args) == 2
        arg = args[0] * args[1]
        return sum_logits([arg], mask=mask, name=scope)
    elif func == 'mul_linear':
        assert len(args) == 2
        arg = args[0] * args[1]
        return linear_logits([arg],
                             bias,
                             bias_start=bias_start,
                             scope=scope,
                             mask=mask,
                             wd=wd,
                             input_keep_prob=input_keep_prob,
                             is_train=is_train)
    elif func == 'proj':
        assert len(args) == 2
        d = args[1].get_shape()[-1]
        proj = linear([args[0]],
                      d,
                      False,
                      bias_start=bias_start,
                      scope=scope,
                      wd=wd,
                      input_keep_prob=input_keep_prob,
                      is_train=is_train)
        return sum_logits([proj * args[1]], mask=mask)
    elif func == 'tri_linear':
        assert len(args) == 2
        new_arg = args[0] * args[1]
        return linear_logits([args[0], args[1], new_arg],
                             bias,
                             bias_start=bias_start,
                             scope=scope,
                             mask=mask,
                             wd=wd,
                             input_keep_prob=input_keep_prob,
                             is_train=is_train)
    else:
        raise ValueError("unknown logit function: {}".format(func))
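
For reference, the 'tri_linear' branch feeds [args[0], args[1], args[0] * args[1]] into linear_logits, i.e. a trilinear similarity between the two inputs (assuming linear concatenates its list arguments along the last axis). A minimal NumPy sketch of that computation, with masking, bias, and weight decay omitted:

import numpy as np

def tri_linear_logits(a, b, w):
    """Trilinear similarity: a learned linear map over [a; b; a * b].

    a, b: [batch, d]; w: [3 * d] weight vector (bias omitted for brevity).
    """
    feats = np.concatenate([a, b, a * b], axis=-1)  # [batch, 3d]
    return feats @ w                                # [batch]

rng = np.random.default_rng(0)
a, b = rng.standard_normal((2, 3)), rng.standard_normal((2, 3))
print(tri_linear_logits(a, b, rng.standard_normal(9)).shape)  # (2,)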
Example 6
def directional_attention_with_dense(rep_tensor, rep_mask, direction=None, scope=None,
                                     keep_prob=1., is_train=None, wd=0., activation='elu',
                                     tensor_dict=None, name=None):
    def scaled_tanh(x, scale=5.):
        return scale * tf.nn.tanh(1./scale * x)

    bs, sl, vec = tf.shape(rep_tensor)[0], tf.shape(rep_tensor)[1], tf.shape(rep_tensor)[2]
    ivec = rep_tensor.get_shape()[2]
    with tf.variable_scope(scope or 'directional_attention_%s' % (direction or 'diag')):
        # mask generation
        sl_indices = tf.range(sl, dtype=tf.int32)
        sl_col, sl_row = tf.meshgrid(sl_indices, sl_indices)
        if direction is None:
            direct_mask = tf.cast(tf.diag(- tf.ones([sl], tf.int32)) + 1, tf.bool)
        elif direction == 'forward':
            direct_mask = tf.greater(sl_row, sl_col)
        else:
            direct_mask = tf.greater(sl_col, sl_row)
        direct_mask_tile = tf.tile(tf.expand_dims(direct_mask, 0), [bs, 1, 1])  # bs,sl,sl
        rep_mask_tile = tf.tile(tf.expand_dims(rep_mask, 1), [1, sl, 1])  # bs,sl,sl
        attn_mask = tf.logical_and(direct_mask_tile, rep_mask_tile)  # bs,sl,sl

        # non-linear
        rep_map = bn_dense_layer(rep_tensor, ivec, True, 0., 'bn_dense_map', activation,
                                 False, wd, keep_prob, is_train)
        rep_map_tile = tf.tile(tf.expand_dims(rep_map, 1), [1, sl, 1, 1])  # bs,sl,sl,vec
        rep_map_dp = dropout(rep_map, keep_prob, is_train)

        # attention
        with tf.variable_scope('attention'):  # bs,sl,sl,vec
            f_bias = tf.get_variable('f_bias', [ivec], tf.float32, tf.constant_initializer(0.))
            dependent = linear(rep_map_dp, ivec, False, scope='linear_dependent')  # bs,sl,vec
            dependent_etd = tf.expand_dims(dependent, 1)  # bs,1,sl,vec
            head = linear(rep_map_dp, ivec, False, scope='linear_head') # bs,sl,vec
            head_etd = tf.expand_dims(head, 2)  # bs,sl,1,vec

            logits = scaled_tanh(dependent_etd + head_etd + f_bias, 5.0)  # bs,sl,sl,vec

            logits_masked = exp_mask_for_high_rank(logits, attn_mask)
            attn_score = tf.nn.softmax(logits_masked, 2)  # bs,sl,sl,vec
            attn_score = mask_for_high_rank(attn_score, attn_mask)

            attn_result = tf.reduce_sum(attn_score * rep_map_tile, 2)  # bs,sl,vec

        with tf.variable_scope('output'):
            o_bias = tf.get_variable('o_bias', [ivec], tf.float32, tf.constant_initializer(0.))
            # input gate
            fusion_gate = tf.nn.sigmoid(
                linear(rep_map, ivec, True, 0., 'linear_fusion_i', False, wd, keep_prob, is_train) +
                linear(attn_result, ivec, True, 0., 'linear_fusion_a', False, wd, keep_prob, is_train) +
                o_bias)
            output = fusion_gate * rep_map + (1-fusion_gate) * attn_result
            output = mask_for_high_rank(output, rep_mask)

        # save attn
        if tensor_dict is not None and name is not None:
            tensor_dict[name + '_dependent'] = dependent
            tensor_dict[name + '_head'] = head
            tensor_dict[name] = attn_score
            tensor_dict[name + '_gate'] = fusion_gate
        return output
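
The mask-generation block above builds a token-pair mask from meshgrid indices: the forward direction lets position i attend only to earlier positions, the backward direction only to later ones, and direction=None drops the diagonal. A NumPy sketch of the same logic, without the batch tiling or the intersection with rep_mask:

import numpy as np

def directional_mask(sl, direction=None):
    """Return a [sl, sl] boolean mask; entry (i, j) is True if query i may attend to key j."""
    idx = np.arange(sl)
    col, row = np.meshgrid(idx, idx)  # row[i, j] = i (query), col[i, j] = j (key)
    if direction is None:
        return row != col             # everything except the diagonal
    if direction == 'forward':
        return row > col              # strictly earlier tokens only
    return col > row                  # strictly later tokens only

print(directional_mask(4, 'forward').astype(int))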
Example 7
def multihead_attention(query,
                        memory,
                        bias,
                        key_size,
                        value_size,
                        output_size,
                        num_heads,
                        keep_prob=None,
                        data_format="NHWC",
                        attention_function="dot_product",
                        dtype=None,
                        scope=None):
    """ Multihead scaled-dot-product attention with input/output
        transformations.

    Args:
        query: a Tensor with shape [batch, length_q, channels] if
            data_format is `NHWC`, [batch, channels, length_q] if
            data_format is `NCHW`
        memory: a Tensor with shape [batch, length_m, channels] if
            data_format is `NHWC`, [batch, channels, length_m] if
            data_format is `NCHW`
        bias: bias Tensor (see attention_bias())
        key_size: an integer
        value_size: an integer
        output_size: an integer
        num_heads: an integer dividing key_size and value_size
        keep_prob: a floating point number
        data_format: "NHWC" or "NCHW"
        attention_function: "dot_product" or "additive"
        dtype: an optional instance of tf.DType
        scope: an optional string

    Returns:
        A Tensor.
    """
    if key_size % num_heads != 0:
        raise ValueError("Key size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (key_size, num_heads))

    if value_size % num_heads != 0:
        raise ValueError("Value size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (value_size, num_heads))

    with tf.variable_scope(scope,
                           default_name="multihead_attention",
                           values=[query, memory],
                           dtype=dtype):

        axis = 2

        if memory is None:
            # self attention
            size = key_size * 2 + value_size
            combined = linear(query,
                              size,
                              True,
                              True,
                              data_format=data_format,
                              scope="qkv_transform")
            q, k, v = tf.split(combined, [key_size, key_size, value_size],
                               axis=axis)
        else:
            q = linear(query,
                       key_size,
                       True,
                       data_format=data_format,
                       scope="q_transform")
            combined = linear(memory,
                              key_size + value_size,
                              True,
                              data_format=data_format,
                              scope="kv_transform")
            k, v = tf.split(combined, [key_size, value_size], axis=axis)

        # split heads
        q = _split_heads(q, num_heads)
        k = _split_heads(k, num_heads)
        v = _split_heads(v, num_heads)

        # scale query
        if attention_function == "dot_product":
            key_depth_per_head = key_size // num_heads
            q *= key_depth_per_head**-0.5

            # attention
            x = dot_product_attention(q, k, v, bias, keep_prob)
        else:
            raise NotImplementedError(
                "unsupported attention function: %s" % attention_function)

        # combine heads
        x = _combine_heads(x)

        x = linear(x,
                   output_size,
                   True,
                   data_format=data_format,
                   scope="output_transform")

        return x
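
Inside the "dot_product" branch, scaling the query by key_depth_per_head**-0.5 and calling dot_product_attention amounts to standard scaled dot-product attention applied per head. A self-contained NumPy sketch of that core, with the bias term and dropout omitted (a hypothetical helper, not the dot_product_attention used above):

import numpy as np

def scaled_dot_product_attention(q, k, v):
    """q: [batch, heads, len_q, d_k]; k: [batch, heads, len_m, d_k]; v: [batch, heads, len_m, d_v]."""
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)       # [batch, heads, len_q, len_m]
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)           # softmax over memory positions
    return weights @ v                                       # [batch, heads, len_q, d_v]

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 5, 8))   # 4 heads, 5 query positions, 8 dims per head
k = rng.standard_normal((2, 4, 7, 8))   # 7 memory positions
v = rng.standard_normal((2, 4, 7, 8))
print(scaled_dot_product_attention(q, k, v).shape)  # (2, 4, 5, 8)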