def test_pure_lsh_wrapper_non_causal_masked(self, num_weights):
  """Checks weight/state shapes of a non-causal, masked PureLSH wrapper.

  Args:
    num_weights: Number of input projection weights the wrapper should
      create (supplied by the test parameterization). The layer adds one
      extra weight for the output projection.
  """
  # BUG FIX: the original body overwrote the parameterized `num_weights`
  # with a hard-coded 2, so every parameterized variant tested the same
  # configuration. The override is removed; `num_weights` flows through.
  with fastmath.use_backend(fastmath.Backend.JAX):
    n_heads = 5
    batch, seqlen, d_head = 3, 32, 8
    n_hashes = 2
    d_model = n_heads * d_head
    layer = efficient_attention.PureLSHSelfAttentionWrapper(
        n_heads=n_heads,
        d_qk=d_head,
        d_v=d_head,
        causal=False,
        masked=True,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=n_hashes,
        n_buckets=4,
        bias=False,
        pure_lsh_implementation=efficient_attention.PureLSHSelfAttention,
        mode='train',
        num_weights=num_weights)

    rng = jax.random.PRNGKey(0)
    rng, x_rng = jax.random.split(rng)

    input_shape = (batch, seqlen, d_model)
    x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32)
    mask = jnp.ones((batch, seqlen), dtype=jnp.int32)

    inp = (x, mask)
    w, s = layer.init(shapes.signature(inp))
    o = layer(inp)

    # Get the actual weights.
    weights = fastmath.tree_leaves(w)
    # Assert number of weights is as expected, the extra 1 is for output.
    self.assertLen(weights, num_weights + 1)
    # Assert each weight is of the expected shape.
    for i in range(num_weights + 1):
      self.assertEqual(weights[i].shape, (d_model, d_model))

    # Test that the output and the x's shape match.
    self.assertEqual(x.shape, o.shape)

    # Assert state is the shape expected.
    state = fastmath.tree_leaves(s)
    self.assertLen(state, 2)
    # buckets
    self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen))
    # rngs
    self.assertEqual(state[1].shape, (batch * n_heads, 2))
def init_weights_and_state(self, input_signature):
  """Initializes weights, then optionally overwrites them from a checkpoint.

  Runs the standard initialization first; if `self.init_checkpoint` is set,
  reads a pre-trained BERT TF checkpoint and replaces every leaf of
  `self.weights` with the corresponding checkpoint tensor (reshaped to this
  model's per-head layout where needed).

  Args:
    input_signature: Shape/dtype signature forwarded to the base
      initializer.
  """
  super().init_weights_and_state(input_signature)
  if self.init_checkpoint is None:
    return

  print('Loading pre-trained weights from', self.init_checkpoint)
  reader = tf.train.load_checkpoint(self.init_checkpoint)

  # Attention kernels are stored as (d_model, d_model); split the second
  # axis into heads of size 64 and move the head axis to the front.
  def per_head_qkv(name):
    tensor = reader.get_tensor(name)
    return tensor.reshape((tensor.shape[0], -1, 64)).swapaxes(0, 1)

  # Output projection: split the *first* axis into (heads, 64).
  def per_head_output(name):
    tensor = reader.get_tensor(name)
    return tensor.reshape((-1, 64, tensor.shape[-1]))

  # Per-head bias: flat vector reshaped to (heads, 64).
  def per_head_bias(name):
    tensor = reader.get_tensor(name)
    return tensor.reshape((-1, 64))

  # NOTE: the order of this list must match the leaf order of
  # `self.weights` exactly; do not reorder entries.
  new_w = [
      reader.get_tensor('bert/embeddings/word_embeddings'),
      reader.get_tensor('bert/embeddings/token_type_embeddings'),
      reader.get_tensor('bert/embeddings/position_embeddings')[None, ...],
      reader.get_tensor('bert/embeddings/LayerNorm/gamma'),
      reader.get_tensor('bert/embeddings/LayerNorm/beta'),
  ]
  for i in range(12):  # 12 layers
    new_w.extend([
        per_head_qkv(f'bert/encoder/layer_{i}/attention/self/query/kernel'),
        per_head_qkv(f'bert/encoder/layer_{i}/attention/self/key/kernel'),
        per_head_qkv(f'bert/encoder/layer_{i}/attention/self/value/kernel'),
        per_head_output(
            f'bert/encoder/layer_{i}/attention/output/dense/kernel'),
        per_head_bias(f'bert/encoder/layer_{i}/attention/self/query/bias'),
        per_head_bias(f'bert/encoder/layer_{i}/attention/self/key/bias'),
        per_head_bias(f'bert/encoder/layer_{i}/attention/self/value/bias'),
        reader.get_tensor(
            f'bert/encoder/layer_{i}/attention/output/dense/bias'),
        reader.get_tensor(
            f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma'),
        reader.get_tensor(
            f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta'),
        reader.get_tensor(
            f'bert/encoder/layer_{i}/intermediate/dense/kernel'),
        reader.get_tensor(f'bert/encoder/layer_{i}/intermediate/dense/bias'),
        reader.get_tensor(f'bert/encoder/layer_{i}/output/dense/kernel'),
        reader.get_tensor(f'bert/encoder/layer_{i}/output/dense/bias'),
        reader.get_tensor(f'bert/encoder/layer_{i}/output/LayerNorm/gamma'),
        reader.get_tensor(f'bert/encoder/layer_{i}/output/LayerNorm/beta'),
    ])
  new_w.extend([
      reader.get_tensor('bert/pooler/dense/kernel'),
      reader.get_tensor('bert/pooler/dense/bias'),
  ])

  # Shape-check each replacement against the freshly-initialized leaf.
  for current, loaded in zip(fastmath.tree_leaves(self.weights), new_w):
    assert current.shape == loaded.shape, (
        f'Expected shape {current.shape}, got shape {loaded.shape}')
  self.weights = jax.tree_unflatten(
      jax.tree_structure(self.weights), new_w)
  # Round-trip through a no-op jit so the weights land on device.
  move_to_device = jax.jit(lambda x: x)
  self.weights = jax.tree_map(move_to_device, self.weights)