def _linear_4d(inputs, output_size, bias, concat=True, data_format="NHWC"):
    # Linear transformation of 4-D tensors, implemented as a 1x1 convolution.
    data_format = check_data_format(data_format)
    channel_axis = 1 if data_format == "NCHW" else -1
    input_size = [item.get_shape()[channel_axis].value for item in inputs]
    outputs = []

    if concat:
        input_size = sum(input_size)
        inputs = tf.concat(inputs, channel_axis)

        shape = [input_size, output_size]
        matrix = tf.get_variable("matrix", shape)
        # [input_size, output_size] => [1, 1, input_size, output_size]
        matrix = tf.expand_dims(tf.expand_dims(matrix, 0), 1)
        output = tf.nn.convolution(inputs, matrix, "VALID",
                                   data_format=data_format)
        outputs.append(output)
    else:
        for i in range(len(input_size)):
            shape = [input_size[i], output_size]
            name = "matrix_%d" % i
            matrix = tf.get_variable(name, shape)
            matrix = tf.expand_dims(tf.expand_dims(matrix, 0), 1)
            output = tf.nn.convolution(inputs[i], matrix, "VALID",
                                       data_format=data_format)
            outputs.append(output)

    output = tf.add_n(outputs)

    if bias:
        bias = tf.get_variable("bias", [output_size])
        output = tf.nn.bias_add(output, bias, data_format=data_format)

    return output
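# Usage sketch (not part of the library): a minimal call of `_linear_4d` on a
# single NHWC activation, assuming TF 1.x graph mode and the module-level
# `import tensorflow as tf` used by the rest of this file. Shapes and the
# scope name are illustrative only.
def _example_linear_4d_usage():
    x = tf.random_normal([8, 1, 20, 512])  # [batch, height, width, channels]
    with tf.variable_scope("example_linear_4d"):
        # project the 512-channel input to 256 channels via a 1x1 convolution
        return _linear_4d([x], 256, bias=True, concat=True,
                          data_format="NHWC")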
def layer_norm(inputs, epsilon=1e-6, data_format="NHWC", dtype=None,
               scope=None):
    with tf.variable_scope(scope, default_name="layer_norm", values=[inputs],
                           dtype=dtype):
        data_format = check_data_format(data_format)
        axis = 1 if data_format == "NCHW" else -1
        channel_size = inputs.get_shape().as_list()[axis]

        scale = tf.get_variable("scale", shape=[channel_size],
                                initializer=tf.ones_initializer())
        offset = tf.get_variable("offset", shape=[channel_size],
                                 initializer=tf.zeros_initializer())

        mean = tf.reduce_mean(inputs, axis=axis, keep_dims=True)
        variance = tf.reduce_mean(tf.square(inputs - mean), axis=axis,
                                  keep_dims=True)

        norm_inputs = (inputs - mean) * tf.rsqrt(variance + epsilon)

        return norm_inputs * scale + offset
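# Usage sketch (not part of the library): layer-normalize a [batch, length,
# channels] tensor over its channel axis, assuming TF 1.x graph mode. The
# shape below is illustrative only.
def _example_layer_norm_usage():
    x = tf.random_normal([8, 10, 512])  # [batch, length, channels]
    return layer_norm(x, data_format="NHWC")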
def multihead_attention(query, memory, bias, key_size, value_size, output_size,
                        num_heads, keep_prob=None, data_format="NHWC",
                        attention_function="dot_product", summaries=False,
                        image_shapes=None, dtype=None, scope=None):
    """ Multihead scaled-dot-product attention with input/output
        transformations.

    Args:
        query: a Tensor with shape [batch, length_q, channels] if
            data_format is `NHWC`, [batch, channels, length_q] if
            data_format is `NCHW`
        memory: a Tensor with shape [batch, length_m, channels] if
            data_format is `NHWC`, [batch, channels, length_m] if
            data_format is `NCHW`
        bias: bias Tensor (see attention_bias())
        key_size: an integer
        value_size: an integer
        output_size: an integer
        num_heads: an integer dividing key_size and value_size
        keep_prob: a floating point number
        data_format: "NHWC" or "NCHW"
        attention_function: "dot_product" or "additive"
        summaries: a boolean
        image_shapes: optional tuple of integer scalars.
            see comments for attention_image_summary()
        dtype: an optional instance of tf.DType
        scope: an optional string

    Returns:
        A Tensor.
    """
    if key_size % num_heads != 0:
        raise ValueError("Key size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (key_size, num_heads))

    if value_size % num_heads != 0:
        raise ValueError("Value size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (value_size, num_heads))

    with tf.variable_scope(scope, default_name="multihead_attention",
                           values=[query, memory], dtype=dtype):
        data_format = check_data_format(data_format)
        axis = 1 if data_format == "NCHW" else 2

        if memory is None:
            # self attention
            size = key_size * 2 + value_size
            combined = linear(query, size, True, True,
                              data_format=data_format, scope="qkv_transform")
            q, k, v = tf.split(combined, [key_size, key_size, value_size],
                               axis=axis)
        else:
            q = linear(query, key_size, True, data_format=data_format,
                       scope="q_transform")
            combined = linear(memory, key_size + value_size, True,
                              data_format=data_format, scope="kv_transform")
            k, v = tf.split(combined, [key_size, value_size], axis=axis)

        # split heads
        q = _split_heads(q, num_heads, data_format=data_format)
        k = _split_heads(k, num_heads, data_format=data_format)
        v = _split_heads(v, num_heads, data_format=data_format)

        if attention_function == "dot_product":
            # scale query
            key_depth_per_head = key_size // num_heads
            q *= key_depth_per_head ** -0.5

            # attention
            x = dot_product_attention(q, k, v, bias, keep_prob, summaries,
                                      image_shapes)
        elif attention_function == "additive":
            x = additive_attention(q, k, v, bias, keep_prob, summaries,
                                   image_shapes)
        else:
            raise ValueError("Unknown attention function")

        # combine heads
        x = _combine_heads(x, data_format=data_format)
        x = linear(x, output_size, True, data_format=data_format,
                   scope="output_transform")

        return x
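# Usage sketch (not part of the library): encoder-style self-attention, where
# `memory` is None so queries, keys and values are all derived from `query`.
# Assumes TF 1.x graph mode; the sizes below and the all-zero (non-masking)
# bias are illustrative only.
def _example_self_attention_usage():
    query = tf.random_normal([8, 20, 512])  # [batch, length_q, channels]
    bias = tf.zeros([1, 1, 1, 20])          # broadcastable, masks nothing
    return multihead_attention(query, None, bias, key_size=512,
                               value_size=512, output_size=512, num_heads=8,
                               keep_prob=0.9, data_format="NHWC")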