def testMultiheadSelfAttentionMemoryEfficient(self):
  """forget=True and forget=False should give identical outputs and grads."""
  if tf.executing_eagerly():
    return  # don't run test in Eager mode
  num_heads = 4
  io_size = 16
  batch = 2
  length = 7
  head_size = 5
  x = np.random.rand(batch, length, io_size)
  dy = np.random.rand(batch, length, io_size)
  with self.session() as session:
    x = tf.to_float(x)
    dy = tf.to_float(dy)
    bias = common_attention.attention_bias_lower_triangle(length)
    wqkv = tf.get_variable(
        "wqkv", [num_heads, 1, io_size, 3 * head_size],
        initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
    wo = tf.get_variable(
        "wo", [num_heads, 1, head_size, io_size],
        initializer=tf.random_normal_initializer(
            stddev=(head_size * num_heads)**-0.5))
    norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
    y = common_attention.multihead_self_attention_memory_efficient(
        x, bias, num_heads, head_size=head_size, forget=False,
        test_vars=(wqkv, wo, norm_scale, norm_bias))
    y_forget = common_attention.multihead_self_attention_memory_efficient(
        x, bias, num_heads, head_size=head_size, forget=True,
        test_vars=(wqkv, wo, norm_scale, norm_bias))
    dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
        ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
    dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
        ys=[y_forget], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
    session.run(tf.global_variables_initializer())
    (y, y_forget,
     dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
     dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run(
         [y, y_forget,
          dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
          dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f])
  self.assertAllClose(y, y_forget)
  self.assertAllClose(dwo, dwo_f)
  self.assertAllClose(dwqkv, dwqkv_f)
  self.assertAllClose(dnorm_scale, dnorm_scale_f)
  self.assertAllClose(dnorm_bias, dnorm_bias_f)
  self.assertAllClose(dx, dx_f)
def testConvHiddenReluMemoryEfficient(self):
  """forget=True and forget=False should give identical outputs and grads."""
  if tf.executing_eagerly():
    return  # don't run test in Eager mode
  batch = 3
  length = 23
  io_size = 16
  filter_size = 7
  x = np.random.rand(batch, length, io_size)
  dy = np.random.rand(batch, length, io_size)
  with self.session() as session:
    x = tf.to_float(x)
    dy = tf.to_float(dy)
    f1 = tf.get_variable("f1", [1, io_size, filter_size])
    f2 = tf.get_variable("f2", [1, filter_size, io_size])
    norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
    y = common_layers.conv_hidden_relu_memory_efficient(
        x, filter_size, forget=False,
        test_vars=(f1, f2, norm_scale, norm_bias))
    y_forget = common_layers.conv_hidden_relu_memory_efficient(
        x, filter_size, forget=True,
        test_vars=(f1, f2, norm_scale, norm_bias))
    dx, df1, df2, dnorm_scale, dnorm_bias = tf.gradients(
        ys=[y], xs=[x, f1, f2, norm_scale, norm_bias], grad_ys=[dy])
    dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
        ys=[y_forget], xs=[x, f1, f2, norm_scale, norm_bias], grad_ys=[dy])
    session.run(tf.global_variables_initializer())
    (y, y_forget,
     dx, df1, df2, dnorm_scale, dnorm_bias,
     dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f) = session.run(
         [y, y_forget,
          dx, df1, df2, dnorm_scale, dnorm_bias,
          dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f])
  self.assertAllClose(y, y_forget)
  self.assertAllClose(df2, df2_f)
  self.assertAllClose(df1, df1_f)
  self.assertAllClose(dnorm_scale, dnorm_scale_f)
  self.assertAllClose(dnorm_bias, dnorm_bias_f)
  self.assertAllClose(dx, dx_f)
def multihead_self_attention_memory_efficient(x,
                                              bias,
                                              num_heads,
                                              head_size=None,
                                              epsilon=1e-6,
                                              forget=True,
                                              test_vars=None,
                                              name=None):
  """Multihead scaled-dot-product self-attention.

  Includes layer norm.

  Returns multihead-self-attention(layer_norm(x))

  Computes one attention head at a time to avoid exhausting memory.

  If forget=True, then forget all forwards activations and recompute on
  the backwards pass.

  Args:
    x: a Tensor with shape [batch, length, input_size]
    bias: an attention bias tensor broadcastable to [batch, 1, length, length]
    num_heads: an integer
    head_size: an optional integer - defaults to input_size/num_heads
    epsilon: a float, for layer norm
    forget: a boolean - forget forwards activations and recompute on backprop
    test_vars: optional tuple of variables for testing purposes
    name: an optional string

  Returns:
    A Tensor.
  """
  io_size = x.get_shape().as_list()[-1]
  if head_size is None:
    assert io_size % num_heads == 0
    head_size = io_size / num_heads

  def forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
    """Forward function."""
    n = common_layers.layer_norm_compute_python(
        x, epsilon, norm_scale, norm_bias)
    wqkv_split = tf.unstack(wqkv, num=num_heads)
    wo_split = tf.unstack(wo, num=num_heads)
    y = 0
    for h in xrange(num_heads):
      with tf.control_dependencies([y] if h > 0 else []):
        combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
        q, k, v = tf.split(combined, 3, axis=2)
        o = scaled_dot_product_attention_simple(q, k, v, attention_bias)
        y += tf.nn.conv1d(o, wo_split[h], 1, "SAME")
    return y

  key = ("multihead_self_attention_memory_efficient %s %s" %
         (num_heads, epsilon))
  if not forget:
    forward_fn = forward_internal
  elif key in _function_cache:
    forward_fn = _function_cache[key]
  else:

    @function.Defun(compiled=True)
    def grad_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias, dy):
      """Custom gradient function."""
      with tf.control_dependencies([dy]):
        n = common_layers.layer_norm_compute_python(
            x, epsilon, norm_scale, norm_bias)
        wqkv_split = tf.unstack(wqkv, num=num_heads)
        wo_split = tf.unstack(wo, num=num_heads)
        deps = []
        dwqkvs = []
        dwos = []
        dn = 0
        for h in xrange(num_heads):
          with tf.control_dependencies(deps):
            combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
            q, k, v = tf.split(combined, 3, axis=2)
            o = scaled_dot_product_attention_simple(q, k, v, attention_bias)
            partial_y = tf.nn.conv1d(o, wo_split[h], 1, "SAME")
            pdn, dwqkvh, dwoh = tf.gradients(
                ys=[partial_y],
                xs=[n, wqkv_split[h], wo_split[h]],
                grad_ys=[dy])
            dn += pdn
            dwqkvs.append(dwqkvh)
            dwos.append(dwoh)
            deps = [dn, dwqkvh, dwoh]
        dwqkv = tf.stack(dwqkvs)
        dwo = tf.stack(dwos)
        with tf.control_dependencies(deps):
          dx, dnorm_scale, dnorm_bias = tf.gradients(
              ys=[n], xs=[x, norm_scale, norm_bias], grad_ys=[dn])
        return (dx, dwqkv, dwo, tf.zeros_like(attention_bias), dnorm_scale,
                dnorm_bias)

    @function.Defun(
        grad_func=grad_fn, compiled=True, separate_compiled_gradients=True)
    def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
      return forward_internal(x, wqkv, wo, attention_bias, norm_scale,
                              norm_bias)

    _function_cache[key] = forward_fn

  if bias is not None:
    bias = tf.squeeze(bias, 1)
  with tf.variable_scope(name, default_name="multihead_attention", values=[x]):
    # TODO(noam): it would be nice to save memory by casting x to float16
    # here, but this causes problems with the gradients.  Figure out if there
    # is a way to leave the gradients as float32.
    if test_vars is not None:
      wqkv, wo, norm_scale, norm_bias = list(test_vars)
    else:
      wqkv = tf.get_variable(
          "wqkv", [num_heads, 1, io_size, 3 * head_size],
          initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
      wo = tf.get_variable(
          "wo", [num_heads, 1, head_size, io_size],
          initializer=tf.random_normal_initializer(
              stddev=(head_size * num_heads)**-0.5))
      norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
    y = forward_fn(x, wqkv, wo, bias, norm_scale, norm_bias)
    y.set_shape(x.get_shape())
    return y
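# Usage sketch (illustrative only, not part of the library): builds a small
# TF1 graph that applies the memory-efficient self-attention above to a
# random batch with a causal bias, mirroring the tests. The helper name
# `_example_multihead_self_attention` is hypothetical, and it assumes
# `attention_bias_lower_triangle` lives in this module, as the tests suggest.
def _example_multihead_self_attention():
  batch, length, io_size, num_heads = 2, 7, 16, 4
  x = tf.random_uniform([batch, length, io_size])
  # Lower-triangular bias makes the attention causal.
  bias = attention_bias_lower_triangle(length)
  # forget=True discards forward activations and recomputes them on the
  # backwards pass, trading compute for memory.
  y = multihead_self_attention_memory_efficient(
      x, bias, num_heads, forget=True, name="example_attention")
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    return sess.run(y)  # numpy array of shape [batch, length, io_size]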