Example #1
This test builds a PureLSHSelfAttentionWrapper layer (non-causal, masked) and checks the shapes of its weights, output, and state.
    def test_pure_lsh_wrapper_non_causal_masked(self, num_weights):
        with fastmath.use_backend(fastmath.Backend.JAX):
            n_heads = 5
            batch, seqlen, d_head = 3, 32, 8
            num_weights = 2  # Fixed locally, overriding the method argument.
            n_hashes = 2
            d_model = n_heads * d_head
            layer = efficient_attention.PureLSHSelfAttentionWrapper(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=False,
                masked=True,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=n_hashes,
                n_buckets=4,
                bias=False,
                pure_lsh_implementation=efficient_attention.PureLSHSelfAttention,
                mode='train',
                num_weights=num_weights)

            rng = jax.random.PRNGKey(0)
            rng, x_rng = jax.random.split(rng)

            input_shape = (batch, seqlen, d_model)
            x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32)
            mask = jnp.ones((batch, seqlen), dtype=jnp.int32)

            inp = (x, mask)
            w, s = layer.init(shapes.signature(inp))
            o = layer(inp)

            # Get the actual weights.
            weights = fastmath.tree_leaves(w)
            # Assert number of weights is as expected, the extra 1 is for output.
            self.assertLen(weights, num_weights + 1)

            # Assert each weight is of the expected shape.
            for i in range(num_weights + 1):
                self.assertEqual(weights[i].shape, (d_model, d_model))

            # Test that the output shape matches the input shape.
            self.assertEqual(x.shape, o.shape)

            # Assert state is the shape expected.
            state = fastmath.tree_leaves(s)
            self.assertLen(state, 2)
            # buckets
            self.assertEqual(state[0].shape,
                             (batch * n_heads, n_hashes * seqlen))
            # rngs
            self.assertEqual(state[1].shape, (batch * n_heads, 2))
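
The method above is an excerpt from a test case, so its imports and class are not shown. Below is a minimal sketch of the surrounding scaffolding, assuming the Trax module layout used in the snippet and absl's parameterized test runner; the class name and parameter values are illustrative, not taken from the original file.

import jax
import jax.numpy as jnp
from absl.testing import parameterized

from trax import fastmath
from trax import shapes
from trax.layers.research import efficient_attention


class PureLSHSelfAttentionWrapperTest(parameterized.TestCase):

    # Illustrative parameterization; it supplies num_weights to the method
    # (the body above also fixes num_weights locally).
    @parameterized.named_parameters(('_weights_2', 2), ('_weights_3', 3))
    def test_pure_lsh_wrapper_non_causal_masked(self, num_weights):
        ...  # method body as shown in Example #1 above
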
Example #2
This method loads pre-trained BERT weights from a TensorFlow checkpoint, reshapes the attention projections into per-head form, checks each tensor's shape against the corresponding model weight, and installs the result as the layer's weights.
  def init_weights_and_state(self, input_signature):
    super().init_weights_and_state(input_signature)
    if self.init_checkpoint is None:
      return

    print('Loading pre-trained weights from', self.init_checkpoint)
    ckpt = tf.train.load_checkpoint(self.init_checkpoint)

    def reshape_qkv(name):
      # (d_model, n_heads * 64) -> (n_heads, d_model, 64): one block per head.
      x = ckpt.get_tensor(name)
      return x.reshape((x.shape[0], -1, 64)).swapaxes(0, 1)
    def reshape_o(name):
      # (n_heads * 64, d_model) -> (n_heads, 64, d_model).
      x = ckpt.get_tensor(name)
      return x.reshape((-1, 64, x.shape[-1]))
    def reshape_bias(name):
      # (n_heads * 64,) -> (n_heads, 64).
      x = ckpt.get_tensor(name)
      return x.reshape((-1, 64))

    new_w = [
        ckpt.get_tensor('bert/embeddings/word_embeddings'),
        ckpt.get_tensor('bert/embeddings/token_type_embeddings'),
        ckpt.get_tensor('bert/embeddings/position_embeddings')[None, ...],
        ckpt.get_tensor('bert/embeddings/LayerNorm/gamma'),
        ckpt.get_tensor('bert/embeddings/LayerNorm/beta'),
    ]

    for i in range(12):  # 12 layers
      new_w += [
          reshape_qkv(f'bert/encoder/layer_{i}/attention/self/query/kernel'),
          reshape_qkv(f'bert/encoder/layer_{i}/attention/self/key/kernel'),
          reshape_qkv(f'bert/encoder/layer_{i}/attention/self/value/kernel'),
          reshape_o(f'bert/encoder/layer_{i}/attention/output/dense/kernel'),
          reshape_bias(f'bert/encoder/layer_{i}/attention/self/query/bias'),
          reshape_bias(f'bert/encoder/layer_{i}/attention/self/key/bias'),
          reshape_bias(f'bert/encoder/layer_{i}/attention/self/value/bias'),
          ckpt.get_tensor(
              f'bert/encoder/layer_{i}/attention/output/dense/bias'),
          ckpt.get_tensor(
              f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma'),
          ckpt.get_tensor(
              f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta'),
          ckpt.get_tensor(f'bert/encoder/layer_{i}/intermediate/dense/kernel'),
          ckpt.get_tensor(f'bert/encoder/layer_{i}/intermediate/dense/bias'),
          ckpt.get_tensor(f'bert/encoder/layer_{i}/output/dense/kernel'),
          ckpt.get_tensor(f'bert/encoder/layer_{i}/output/dense/bias'),
          ckpt.get_tensor(f'bert/encoder/layer_{i}/output/LayerNorm/gamma'),
          ckpt.get_tensor(f'bert/encoder/layer_{i}/output/LayerNorm/beta'),
      ]

    new_w += [
        ckpt.get_tensor('bert/pooler/dense/kernel'),
        ckpt.get_tensor('bert/pooler/dense/bias'),
    ]

    # Sanity-check that every checkpoint tensor matches the shape of the
    # model weight it replaces.
    for a, b in zip(fastmath.tree_leaves(self.weights), new_w):
      assert a.shape == b.shape, (
          f'Expected shape {a.shape}, got shape {b.shape}')
    self.weights = jax.tree_unflatten(jax.tree_structure(self.weights), new_w)
    move_to_device = jax.jit(lambda x: x)
    self.weights = jax.tree_map(move_to_device, self.weights)
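
The reshape helpers above only rearrange the fused projection matrices into a per-head layout. Below is a minimal standalone sketch of the same transformations in NumPy, assuming BERT-base sizes (hidden size 768, 12 heads with head dimension 64); the arrays are zero placeholders rather than real checkpoint tensors.

import numpy as np

d_model, n_heads, d_head = 768, 12, 64  # assumed BERT-base sizes

# Same rearrangement as reshape_qkv: one (d_model, d_head) block per head.
qkv_kernel = np.zeros((d_model, n_heads * d_head), dtype=np.float32)
per_head_qkv = qkv_kernel.reshape((qkv_kernel.shape[0], -1, 64)).swapaxes(0, 1)
assert per_head_qkv.shape == (n_heads, d_model, d_head)

# Same rearrangement as reshape_o: the output kernel is split along its rows.
o_kernel = np.zeros((n_heads * d_head, d_model), dtype=np.float32)
per_head_o = o_kernel.reshape((-1, 64, o_kernel.shape[-1]))
assert per_head_o.shape == (n_heads, d_head, d_model)

# Same rearrangement as reshape_bias.
qkv_bias = np.zeros((n_heads * d_head,), dtype=np.float32)
per_head_bias = qkv_bias.reshape((-1, 64))
assert per_head_bias.shape == (n_heads, d_head)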