Example #1
    def testMultiheadSelfAttentionMemoryEfficient(self):
        if tf.executing_eagerly():
            return  # don't run test in Eager mode

        num_heads = 4
        io_size = 16
        batch = 2
        length = 7
        head_size = 5
        x = np.random.rand(batch, length, io_size)
        dy = np.random.rand(batch, length, io_size)
        with self.session() as session:
            x = tf.to_float(x)
            dy = tf.to_float(dy)
            bias = common_attention.attention_bias_lower_triangle(length)
            wqkv = tf.get_variable(
                "wqkv", [num_heads, 1, io_size, 3 * head_size],
                initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
            wo = tf.get_variable("wo", [num_heads, 1, head_size, io_size],
                                 initializer=tf.random_normal_initializer(
                                     stddev=(head_size * num_heads)**-0.5))
            norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
            y = common_attention.multihead_self_attention_memory_efficient(
                x,
                bias,
                num_heads,
                head_size=head_size,
                forget=False,
                test_vars=(wqkv, wo, norm_scale, norm_bias))
            y_forget = common_attention.multihead_self_attention_memory_efficient(
                x,
                bias,
                num_heads,
                head_size=head_size,
                forget=True,
                test_vars=(wqkv, wo, norm_scale, norm_bias))
            dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
                ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
            dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
                ys=[y_forget],
                xs=[x, wqkv, wo, norm_scale, norm_bias],
                grad_ys=[dy])
            session.run(tf.global_variables_initializer())
            (y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f,
             dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run([
                 y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f,
                 dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f
             ])
        self.assertAllClose(y, y_forget)
        self.assertAllClose(dwo, dwo_f)
        self.assertAllClose(dwqkv, dwqkv_f)
        self.assertAllClose(dnorm_scale, dnorm_scale_f)
        self.assertAllClose(dnorm_bias, dnorm_bias_f)
        self.assertAllClose(dx, dx_f)
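The test above asserts that forget=True, which discards forward activations and recomputes them during backprop, produces the same outputs and gradients as forget=False. The same rematerialization idea is exposed generically in newer TensorFlow releases (TF 2.x has tf.recompute_grad); a minimal sketch, where `block` is a hypothetical stand-in and not part of the test above:

import tensorflow as tf

def block(x):
    # A stand-in for an activation-heavy layer whose intermediates we would
    # rather recompute during backprop than keep in memory.
    return tf.nn.relu(tf.matmul(x, x, transpose_b=True))

# Wrapping the function discards its forward activations and recomputes
# them when gradients are requested, trading compute for memory.
checkpointed_block = tf.recompute_grad(block)
y = checkpointed_block(tf.random.normal([4, 8]))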
Example #2
    def testConvHiddenReluMemoryEfficient(self):
        if tf.executing_eagerly():
            return  # don't run test in Eager mode

        batch = 3
        length = 23
        io_size = 16
        filter_size = 7
        x = np.random.rand(batch, length, io_size)
        dy = np.random.rand(batch, length, io_size)
        with self.session() as session:
            x = tf.to_float(x)
            dy = tf.to_float(dy)
            f1 = tf.get_variable("f1", [1, io_size, filter_size])
            f2 = tf.get_variable("f2", [1, filter_size, io_size])
            norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
            y = common_layers.conv_hidden_relu_memory_efficient(
                x,
                filter_size,
                forget=False,
                test_vars=(f1, f2, norm_scale, norm_bias))
            y_forget = common_layers.conv_hidden_relu_memory_efficient(
                x,
                filter_size,
                forget=True,
                test_vars=(f1, f2, norm_scale, norm_bias))
            dx, df1, df2, dnorm_scale, dnorm_bias = tf.gradients(
                ys=[y], xs=[x, f1, f2, norm_scale, norm_bias], grad_ys=[dy])
            dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
                ys=[y_forget],
                xs=[x, f1, f2, norm_scale, norm_bias],
                grad_ys=[dy])
            session.run(tf.global_variables_initializer())
            (y, y_forget, dx, df1, df2, dnorm_scale, dnorm_bias, dx_f, df1_f,
             df2_f, dnorm_scale_f, dnorm_bias_f) = session.run([
                 y, y_forget, dx, df1, df2, dnorm_scale, dnorm_bias, dx_f,
                 df1_f, df2_f, dnorm_scale_f, dnorm_bias_f
             ])
        self.assertAllClose(y, y_forget)
        self.assertAllClose(df2, df2_f)
        self.assertAllClose(df1, df1_f)
        self.assertAllClose(dnorm_scale, dnorm_scale_f)
        self.assertAllClose(dnorm_bias, dnorm_bias_f)
        self.assertAllClose(dx, dx_f)
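As in Example #1, this test checks that recomputing activations (forget=True) leaves outputs and gradients unchanged. For orientation, the computation under test is a layer-normalized position-wise feed-forward; a minimal non-memory-efficient sketch of what conv_hidden_relu_memory_efficient computes, assuming the layer-norm helper shown in Example #5 (an illustration, not the library implementation):

def conv_hidden_relu_reference(x, f1, f2, norm_scale, norm_bias, epsilon=1e-6):
    # layer norm, then position-wise feed-forward: 1x1 conv -> relu -> 1x1 conv
    n = common_layers.layer_norm_compute_python(x, epsilon, norm_scale, norm_bias)
    h = tf.nn.relu(tf.nn.conv1d(n, f1, 1, "SAME"))
    return tf.nn.conv1d(h, f2, 1, "SAME")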
Example #3
  def testMultiheadSelfAttentionMemoryEfficient(self):
   num_heads = 4
   io_size = 16
   batch = 2
   length = 7
   head_size = 5
   x = np.random.rand(batch, length, io_size)
   dy = np.random.rand(batch, length, io_size)
   with self.test_session() as session:
     x = tf.to_float(x)
     dy = tf.to_float(dy)
     bias = common_attention.attention_bias_lower_triangle(length)
     wqkv = tf.get_variable(
         "wqkv", [num_heads, 1, io_size, 3 * head_size],
         initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
     wo = tf.get_variable(
         "wo", [num_heads, 1, head_size, io_size],
         initializer=tf.random_normal_initializer(
             stddev=(head_size * num_heads)**-0.5))
     norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
     y = common_attention.multihead_self_attention_memory_efficient(
         x, bias, num_heads, head_size=head_size, forget=False,
         test_vars=(wqkv, wo, norm_scale, norm_bias))
     y_forget = common_attention.multihead_self_attention_memory_efficient(
         x, bias, num_heads, head_size=head_size, forget=True,
         test_vars=(wqkv, wo, norm_scale, norm_bias))
     dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
         ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
     dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
         ys=[y_forget], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
     session.run(tf.global_variables_initializer())
     (y, y_forget,
      dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
      dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run(
          [y, y_forget,
           dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
           dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f])
   self.assertAllClose(y, y_forget)
   self.assertAllClose(dwo, dwo_f)
   self.assertAllClose(dwqkv, dwqkv_f)
   self.assertAllClose(dnorm_scale, dnorm_scale_f)
   self.assertAllClose(dnorm_bias, dnorm_bias_f)
   self.assertAllClose(dx, dx_f)
Example #4
  def testConvHiddenReluMemoryEfficient(self):
   batch = 3
   length = 23
   io_size = 16
   filter_size = 7
   x = np.random.rand(batch, length, io_size)
   dy = np.random.rand(batch, length, io_size)
   with self.test_session() as session:
     x = tf.to_float(x)
     dy = tf.to_float(dy)
     f1 = tf.get_variable("f1", [1, io_size, filter_size])
     f2 = tf.get_variable("f2", [1, filter_size, io_size])
     norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
     y = common_layers.conv_hidden_relu_memory_efficient(
         x, filter_size, forget=False,
         test_vars=(f1, f2, norm_scale, norm_bias))
     y_forget = common_layers.conv_hidden_relu_memory_efficient(
         x, filter_size, forget=True,
         test_vars=(f1, f2, norm_scale, norm_bias))
     dx, df1, df2, dnorm_scale, dnorm_bias = tf.gradients(
         ys=[y], xs=[x, f1, f2, norm_scale, norm_bias], grad_ys=[dy])
     dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
         ys=[y_forget], xs=[x, f1, f2, norm_scale, norm_bias], grad_ys=[dy])
     session.run(tf.global_variables_initializer())
     (y, y_forget,
      dx, df1, df2, dnorm_scale, dnorm_bias,
      dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f) = session.run(
          [y, y_forget,
           dx, df1, df2, dnorm_scale, dnorm_bias,
           dx_f, df1_f, df2_f, dnorm_scale_f, dnorm_bias_f])
   self.assertAllClose(y, y_forget)
   self.assertAllClose(df2, df2_f)
   self.assertAllClose(df1, df1_f)
   self.assertAllClose(dnorm_scale, dnorm_scale_f)
   self.assertAllClose(dnorm_bias, dnorm_bias_f)
   self.assertAllClose(dx, dx_f)
Example #5
def multihead_self_attention_memory_efficient(x,
                                              bias,
                                              num_heads,
                                              head_size=None,
                                              epsilon=1e-6,
                                              forget=True,
                                              test_vars=None,
                                              name=None):
    """Multihead scaled-dot-product self-attention.

  Includes layer norm.

  Returns multihead-self-attention(layer_norm(x))

  Computes one attention head at a time to avoid exhausting memory.

  If forget=True, then forget all forwards activations and recompute on
  the backwards pass.

  Args:
    x: a Tensor with shape [batch, length, input_size]
    bias: an attention bias tensor broadcastable to [batch, 1, length, length]
    num_heads: an integer
    head_size: an optional integer - defaults to input_size/num_heads
    epsilon: a float, for layer norm
    forget: a boolean - forget forwards activations and recompute on backprop
    test_vars: optional tuple of variables for testing purposes
    name: an optional string

  Returns:
    A Tensor.
  """
    io_size = x.get_shape().as_list()[-1]
    if head_size is None:
        assert io_size % num_heads == 0
        head_size = io_size // num_heads

    def forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
        """Forward function."""
        n = common_layers.layer_norm_compute_python(x, epsilon, norm_scale,
                                                    norm_bias)
        wqkv_split = tf.unstack(wqkv, num=num_heads)
        wo_split = tf.unstack(wo, num=num_heads)
        y = 0
        for h in range(num_heads):
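            # Process one head at a time; depending on the running sum `y`
            # serializes the heads so that only one head's activations are
            # live at any point.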
            with tf.control_dependencies([y] if h > 0 else []):
                combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
                q, k, v = tf.split(combined, 3, axis=2)
                o = scaled_dot_product_attention_simple(
                    q, k, v, attention_bias)
                y += tf.nn.conv1d(o, wo_split[h], 1, "SAME")
        return y

    key = ("multihead_self_attention_memory_efficient %s %s" %
           (num_heads, epsilon))
    if not forget:
        forward_fn = forward_internal
    elif key in _function_cache:
        forward_fn = _function_cache[key]
    else:

        @function.Defun(compiled=True)
        def grad_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias, dy):
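            """Custom gradient: recompute the forward activations head-by-head
            rather than storing them between the passes."""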
            with tf.control_dependencies([dy]):
                n = common_layers.layer_norm_compute_python(
                    x, epsilon, norm_scale, norm_bias)
                wqkv_split = tf.unstack(wqkv, num=num_heads)
                wo_split = tf.unstack(wo, num=num_heads)
                deps = []
                dwqkvs = []
                dwos = []
                dn = 0
                for h in range(num_heads):
                    with tf.control_dependencies(deps):
                        combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
                        q, k, v = tf.split(combined, 3, axis=2)
                        o = scaled_dot_product_attention_simple(
                            q, k, v, attention_bias)
                        partial_y = tf.nn.conv1d(o, wo_split[h], 1, "SAME")
                        pdn, dwqkvh, dwoh = tf.gradients(
                            ys=[partial_y],
                            xs=[n, wqkv_split[h], wo_split[h]],
                            grad_ys=[dy])
                        dn += pdn
                        dwqkvs.append(dwqkvh)
                        dwos.append(dwoh)
                        deps = [dn, dwqkvh, dwoh]
                dwqkv = tf.stack(dwqkvs)
                dwo = tf.stack(dwos)
                with tf.control_dependencies(deps):
                    dx, dnorm_scale, dnorm_bias = tf.gradients(
                        ys=[n], xs=[x, norm_scale, norm_bias], grad_ys=[dn])
                return (dx, dwqkv, dwo, tf.zeros_like(attention_bias),
                        dnorm_scale, dnorm_bias)

        @function.Defun(grad_func=grad_fn,
                        compiled=True,
                        separate_compiled_gradients=True)
        def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
            return forward_internal(x, wqkv, wo, attention_bias, norm_scale,
                                    norm_bias)

        _function_cache[key] = forward_fn

    if bias is not None:
        # The single-head attention helper takes a [batch, length, length]
        # bias, so drop the broadcastable heads axis here.
        bias = tf.squeeze(bias, 1)
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[x]):
        # TODO(noam): it would be nice to save memory by casting x to float16
        # here, but this causes problems with the gradients.  Figure out if there
        # is a way to leave the gradients as float32.
        if test_vars is not None:
            wqkv, wo, norm_scale, norm_bias = list(test_vars)
        else:
            wqkv = tf.get_variable(
                "wqkv", [num_heads, 1, io_size, 3 * head_size],
                initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
            wo = tf.get_variable("wo", [num_heads, 1, head_size, io_size],
                                 initializer=tf.random_normal_initializer(
                                     stddev=(head_size * num_heads)**-0.5))
            norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
        y = forward_fn(x, wqkv, wo, bias, norm_scale, norm_bias)
        y.set_shape(x.get_shape())
        return y
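The helper scaled_dot_product_attention_simple is referenced above but not included in this excerpt. A minimal single-head sketch consistent with how it is called here, with q, k, v of shape [batch, length, depth] and bias broadcastable to [batch, length, length] (a reconstruction, not necessarily the library's exact code):

def scaled_dot_product_attention_simple(q, k, v, bias, name=None):
    """Single-head scaled dot-product attention (sketch)."""
    with tf.variable_scope(
            name, default_name="scaled_dot_product_attention_simple"):
        scalar = tf.rsqrt(tf.to_float(tf.shape(q)[2]))  # 1/sqrt(depth_k)
        logits = tf.matmul(q * scalar, k, transpose_b=True)
        if bias is not None:
            logits += bias
        weights = tf.nn.softmax(logits, name="attention_weights")
        return tf.matmul(weights, v)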