Example #1
import tensorflow as tf

import weight_ratios as wr  # assumed import path for the project's weight-ratio utilities


def weighted_sum(inputs, weights, params, flatten=False):
    """Weighted sum of a list of tensors plus the weight ratios of each input."""
    assert len(inputs) == len(weights)
    output = tf.add_n([inputs[i] * weights[i] for i in range(len(inputs))])

    # Propagate weight ratios through the sum; params.stab is a stabilizer
    # that keeps the ratio denominators away from zero.
    weight_ratios = wr.weight_ratio_weighted_sum(inputs, weights, output,
                                                 stab=params.stab,
                                                 flatten=flatten)

    return {"output": output, "weight_ratios": weight_ratios}
Example #2
def layer_norm(inputs, w_x_inp, params, epsilon=1e-6, dtype=None, scope=None):
    """
    Layer Normalization with weight-ratio propagation
    :param inputs: A Tensor of shape [..., channel_size]
    :param w_x_inp: A Tensor of weight ratios w.r.t. the source,
        of shape [batch_size, len_src, len, dim]
    :param params: A hyper-parameter object; params.stab is the stabilizer
        used in the weight-ratio computations
    :param epsilon: A floating point number
    :param dtype: An optional instance of tf.DType
    :param scope: An optional string
    :returns: A dict with "outputs" (a Tensor with the same shape as inputs)
        and "weight_ratios" (the propagated weight ratios)
    """
    with tf.variable_scope(scope,
                           default_name="layer_norm",
                           values=[inputs],
                           dtype=dtype):
        channel_size = inputs.get_shape().as_list()[-1]

        scale = tf.get_variable("scale",
                                shape=[channel_size],
                                initializer=tf.ones_initializer())

        offset = tf.get_variable("offset",
                                 shape=[channel_size],
                                 initializer=tf.zeros_initializer())

        mean = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        variance = tf.reduce_mean(tf.square(inputs - mean),
                                  axis=-1,
                                  keepdims=True)

        averaged = inputs - mean
        norm_inputs = averaged * tf.rsqrt(variance + epsilon)

        # Ratios of each input element with respect to the per-channel mean.
        w_inp_mean = wr.weight_ratio_mean(inputs, mean, stab=params.stab)
        # Ratios through the centering step, averaged = inputs - mean,
        # treated as a weighted sum with weights (1, -1).
        w_inp_out, w_mean_out = wr.weight_ratio_weighted_sum([inputs, mean],
                                                             [1., -1.],
                                                             averaged,
                                                             stab=params.stab,
                                                             flatten=True)
        # Chain rule over the two paths from the source to the output:
        # the direct path through inputs and the indirect path through mean.
        w_x_mean = tf.reduce_sum(w_x_inp * tf.expand_dims(w_inp_mean, 1), -1)
        w_inp_out = tf.expand_dims(w_inp_out, 1)
        w_mean_out = tf.expand_dims(w_mean_out, 1)
        w_x_out = w_x_inp * w_inp_out
        w_x_out += tf.expand_dims(w_x_mean, -1) * w_mean_out

        return {
            "outputs": norm_inputs * scale + offset,
            "weight_ratios": w_x_out
        }
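
The mean over the channel axis is itself a weighted sum with uniform weights
1/d, so wr.weight_ratio_mean can be read as the same ratio rule specialized to
that case. A hedged sketch, assuming the stabilized-share convention from the
sketch above (the real module may differ):

def weight_ratio_mean_sketch(inputs, mean, stab=1e-9):
    # mean = (1/d) * sum_i x_i, so element i's share is (x_i / d) / mean.
    d = tf.cast(tf.shape(inputs)[-1], inputs.dtype)
    return (inputs / d) / (mean + stab)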
Example #3
def residual_fn(x, y, w_x_last, w_x_inp, params, keep_prob=None):
    """Residual connection output = x + y with weight-ratio propagation.

    x, y: Tensors of shape [batch_size, length, dim]
    w_x_last: weight ratios of the source w.r.t. x, [bs, len_src, len, dim]
    w_x_inp: weight ratios of the source w.r.t. y, same shape
    keep_prob: an optional dropout keep probability applied to y
    """
    if keep_prob and keep_prob < 1.0:
        y = tf.nn.dropout(y, keep_prob)
    batchsize = tf.shape(x)[0]
    len_inp = tf.shape(x)[1]
    len_src = tf.shape(w_x_last)[1]
    dim = tf.shape(x)[2]
    result = {}
    result["output"] = x + y
    # Flatten positions and channels so the ratio of every element of x and y
    # in the sum is computed in a single call.
    x_down = tf.reshape(x, [batchsize, -1])
    y_down = tf.reshape(y, [batchsize, -1])
    z_down = tf.reshape(result["output"], [batchsize, -1])

    w_last_out, w_inp_out = wr.weight_ratio_weighted_sum([x_down, y_down],
                                                         [1., 1.],
                                                         z_down,
                                                         stab=params.stab,
                                                         flatten=True)
    # w_last_out and w_inp_out have shape [batch_size, length * dim]; restore
    # [batch_size, 1, length, dim] so they broadcast over the len_src axis.
    w_last_out = tf.reshape(w_last_out, [batchsize, 1, len_inp, dim])
    w_inp_out = tf.reshape(w_inp_out, [batchsize, 1, len_inp, dim])
    # Combine the two residual branches: source -> x and source -> y.
    w_x_out = w_x_last * w_last_out + w_x_inp * w_inp_out
    result["weight_ratio"] = w_x_out
    return result
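
A hypothetical shape check that wires residual_fn together in TF 1.x graph
mode; the Params class, tensor sizes, and keep_prob value are illustrative
only, not part of the project.

class Params(object):
    stab = 1e-9  # assumed small positive stabilizer

bs, len_src, length, dim = 2, 5, 7, 16
x = tf.random_normal([bs, length, dim])
y = tf.random_normal([bs, length, dim])
w_x_last = tf.random_normal([bs, len_src, length, dim])  # ratios into x
w_x_inp = tf.random_normal([bs, len_src, length, dim])   # ratios into y

res = residual_fn(x, y, w_x_last, w_x_inp, Params(), keep_prob=0.9)
# res["output"]: [bs, length, dim]
# res["weight_ratio"]: [bs, len_src, length, dim]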