def _test_fast_inference(self, model_cls, x, input_signature,
                             common_kwargs, *test_kwargs):
        ref_model = model_cls(use_reference_code=True,
                              mode='eval',
                              **common_kwargs)
        weights, state = ref_model.init(input_signature)

        ref_out, _ = ref_model.pure_fn(x,
                                       weights,
                                       state,
                                       rng=jax.random.PRNGKey(0))

        def get_slice(pytree, i):
            def get_slice_for_val(x):
                if isinstance(x, shapes.ShapeDtype):
                    return shapes.ShapeDtype(shape=x.shape[:1] + (1, ) +
                                             x.shape[2:],
                                             dtype=x.dtype)
                else:
                    return x[:, i:i + 1]

            return jax.tree_map(get_slice_for_val, pytree)

        seqlen = x[0].shape[1] if isinstance(x, (tuple, list)) else x.shape[1]

        for kwargs in test_kwargs:
            test_model = model_cls(mode='predict', **common_kwargs, **kwargs)
            cur_state = test_model.init(get_slice(input_signature, 0))[1]
            out = []
            for i in range(seqlen):
                cur_out, cur_state = test_model.pure_fn(
                    get_slice(x, i), weights, cur_state, jax.random.PRNGKey(0))
                out.append(cur_out)
            out = jnp.concatenate(out, axis=1)

            self.assertAllClose(out, ref_out, rtol=1e-3, atol=1e-3)
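
For reference, the per-token slicing that drives this test reduces to a simple identity on plain arrays (a minimal sketch; the real get_slice above additionally maps over pytrees and ShapeDtype signatures via jax.tree_map):

import jax.numpy as jnp

x = jnp.arange(2 * 4 * 3).reshape(2, 4, 3)           # (batch, seqlen, features)
slices = [x[:, i:i + 1] for i in range(x.shape[1])]  # one token at a time
assert jnp.array_equal(jnp.concatenate(slices, axis=1), x)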
Example #2
def forward(self, x):
    rng = self.rng
    batch_size, length = x.shape[0], x.shape[1]
    max_pos = min(self._bases)**self._n_digits
    rng1, rng2, rng3 = fastmath.random.split(rng, 3)
    assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
    positions = jnp.arange(0, length)[None, :]
    if self._mode == 'train':
        # In one out of self._start_from_zero_one_in training cases, still
        # start from 0 to match eval exactly.
        start_from_nonzero = jax.random.randint(
            rng1, (batch_size,), 0, self._start_from_zero_one_in)
        start_from_nonzero = jnp.minimum(1, start_from_nonzero)
        random_start = jax.random.randint(rng2, (batch_size,), 0,
                                          max_pos - length)
        random_start *= start_from_nonzero
        positions += random_start[:, None]
    res = []
    for bn, base in enumerate(self._bases):
        pos_embeddings = []
        cur_positions = positions
        for i in range(self._n_digits):
            cur_indices = jnp.mod(cur_positions, base)
            cur_positions = cur_positions // base
            s = self.weights[bn][i]
            pos_embeddings.append(
                cur_indices.astype(jnp.float32)[:, :, None] * s)
        embeddings = jnp.concatenate(pos_embeddings, axis=-1)
        if self._mode == 'train':
            base_dropout = jax.random.randint(rng3, (batch_size,), 0,
                                              self._base_dropout_one_in)
            base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
            embeddings *= base_dropout[:, None, None]
        res.append(embeddings)
    res = sum(res) + jnp.zeros_like(x)
    return x + res
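
The inner loop above performs a digit decomposition of each position in the given base. A minimal sketch of that step alone, assuming base 3 and 3 digits:

import jax.numpy as jnp

positions = jnp.array([11])
base, n_digits = 3, 3
digits = []
for _ in range(n_digits):
    digits.append(jnp.mod(positions, base))  # least-significant digit first
    positions = positions // base
print([int(d[0]) for d in digits])  # [2, 0, 1], since 11 = 2 + 0*3 + 1*9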
Example #3
def _fast_matrix_shift(x, funnel_factor, is_upsampling=False):
  """Fast matrix shift.

  Implements necessary shift for relative positional attention calculations.
  Based on funnel_factor and information whether we perform upsampling
  or downsampling it calculates necessary shift and interval at which
  we pick correct values for attention.

  Args:
    x: matrix.
    funnel_factor: factor to be used for shift.
    is_upsampling: determines whether perform upsampling.

  Returns:
    Shifted matrix x.
  """
  #  shift: i-th row is shifted by i * shift elements to the left
  #  k: after shift, we pick every kth element

  if is_upsampling:
    k = funnel_factor
    shift = 1
  else:
    k = 1
    shift = funnel_factor

  bsz, n_head = x.shape[0], x.shape[1]
  qlen, klen = x.shape[2], (x.shape[3] + 1) // 2

  zero_pad = jnp.zeros((bsz, n_head, qlen, shift))
  x = jnp.concatenate([zero_pad, x], axis=3)
  x = x.reshape(bsz, n_head, 2 * klen - 1 + shift, qlen)
  x = x[:, :, shift:, :]
  x = x.reshape(bsz, n_head, qlen, klen * 2 - 1)
  x = x[:, :, :, shift - 1:shift - 1 + klen:k]
  return x
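
A quick sanity check of the downsampling path (funnel_factor=1, so shift=1 and k=1): each row's window over the 2*klen-1 relative scores slides one step left per row.

import jax.numpy as jnp

x = jnp.arange(1., 7.).reshape(1, 1, 2, 3)   # qlen=2, width 2*klen-1=3
y = _fast_matrix_shift(x, funnel_factor=1)
print(x[0, 0])   # [[1. 2. 3.] [4. 5. 6.]]
print(y[0, 0])   # [[2. 3.] [4. 5.]]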
Example #4
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

        Args:
          x: Tensor of same shape and dtype as the input signature used to
            initialize this layer.

        Returns:
          Tensor of same shape and dtype as the input, except the final
          dimension is the layer's `filters` value, and the second-to-last
          dimension is shrunk if 'VALID' padding is used with a kernel_size
          greater than one.
        """
        if self._use_bias:
            if not isinstance(self.weights, (tuple, list)):
                raise ValueError(f'Weights should be a (w, b) tuple or list; '
                                 f'instead got: {self.weights}')
            w, b = self.weights
        else:
            w = self.weights

        linear_results_before_shifting = jnp.einsum('...lp,lkpd->...lkd', x, w)
        # TODO(jaszczur): this could be run after padding for better efficiency

        if self._kernel_size == 1:
            # With kernel size 1 we don't have to split or shift anything.
            linear_result = jnp.squeeze(linear_results_before_shifting,
                                        axis=-2)
        else:
            # We computed a result for every "pixel", but each direction from the
            # receptive field (there are 'self._kernel_size' such directions) must be
            # shifted by a different amount. The easiest way to do it is to split
            # the tensor to 'self._kernel_size' smaller tensors, shift each one
            # appropriately, and then sum them together.
            split_shifting_linear_results = jnp.split(
                linear_results_before_shifting, self._kernel_size, axis=-2)

            for i in range(self._kernel_size):
                # Each tensor has to be shifted a different amount.
                if self._padding == 'WRAP':
                    # We can shift by padding and cutting. With 'wrap' padding we
                    # essentially have a torus.
                    padding = [(0, 0)
                               for _ in split_shifting_linear_results[i].shape]
                    padding[-3] = ((self._kernel_size - 1) - i, i)
                    split_shifting_linear_results[i] = jnp.pad(
                        split_shifting_linear_results[i], padding, mode='wrap')
                    split_shifting_linear_results[
                        i] = split_shifting_linear_results[i][
                            ..., (self._kernel_size - 1) //
                            2:-(self._kernel_size - 1) // 2, :, :]
                elif self._padding == 'SAME':
                    # We can shift by padding and cutting.
                    padding = [(0, 0)
                               for _ in split_shifting_linear_results[i].shape]
                    padding[-3] = ((self._kernel_size - 1) - i, i)
                    split_shifting_linear_results[i] = jnp.pad(
                        split_shifting_linear_results[i], padding)
                    split_shifting_linear_results[
                        i] = split_shifting_linear_results[i][
                            ..., (self._kernel_size - 1) //
                            2:-(self._kernel_size - 1) // 2, :, :]
                    # TODO(jaszczur): improve efficiency by not padding things to cut
                elif self._padding == 'VALID':
                    # We don't need to shift - just cut the leftmost and rightmost values.
                    cut_left = (self._kernel_size - 1) - i
                    cut_right = split_shifting_linear_results[i].shape[-3] - i
                    split_shifting_linear_results[
                        i] = split_shifting_linear_results[i][
                            ..., cut_left:cut_right, :, :]
                else:
                    raise ValueError(f'Invalid padding {self._padding}')
            # After shifting.
            shifted_linear_results = jnp.concatenate(
                split_shifting_linear_results, axis=-2)
            linear_result = jnp.sum(shifted_linear_results, axis=-2)

        if self._use_bias:
            return linear_result + b
        else:
            return linear_result
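
A toy check of the 'VALID' shift-and-sum alignment, using scalar features and position-independent weights (the layer itself computes position-dependent products via the einsum above): summing the per-offset slices reproduces a valid 1-D convolution.

import jax.numpy as jnp

s = jnp.array([1., 2., 3., 4., 5.])
w = jnp.array([10., 1., 0.1])        # one weight per receptive-field offset
K = len(w)
per_offset = [w[i] * s for i in range(K)]
# 'VALID' alignment: offset i keeps positions [(K-1)-i, len-i) before summing.
valid = sum(p[(K - 1) - i:len(s) - i] for i, p in enumerate(per_offset))
print(valid)                               # [32.1 43.2 54.3]
print(jnp.convolve(s, w, mode='valid'))    # identical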
Example #5
def forward(self, xs):
    return jnp.concatenate(xs, axis=self._axis)
Example #6
def forward(self, xs):
    """Executes this layer as part of a forward pass through the model."""
    return jnp.concatenate(xs, axis=self._axis)
Example #7
def test_dot_product_self_attention(target):
    n_heads = 2
    d_head = 3
    successful_cases = 0
    failed_cases = []

    q = jnp.array([[1, 0, 0], [0, 1, 0]])
    k = jnp.array([[1, 2, 3], [4, 5, 6]])
    v = jnp.array([[0, 1, 0], [1, 0, 1]])
    m = jnp.array([[0, 0], [-1e9, 0]])

    test_cases = [
        {
            "name": "test dummy tensors",
            "input": {"q": q[None, :], "k": k[None, :], "v": v[None, :],},
            "expected": jnp.array(
                [[[0.0, 1.0, 0.0], [0.8496746, 0.15032543, 0.8496746]]]
            ),
            "error_message": [
                "Expected shape does not match",
                "Expected output does not match.",
            ],
        },
        {
            "name": "test dummy tensors",
            "input": {
                "q": jnp.array([jnp.concatenate([q, q], axis=-1)]),
                "k": jnp.array([jnp.concatenate([k, k], axis=-1)]),
                "v": jnp.array([jnp.concatenate([v, v], axis=-1)]),
            },
            "expected": jnp.array(
                [
                    [
                        [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        [
                            0.9205239,
                            0.07947586,
                            0.9205239,
                            0.9205239,
                            0.07947586,
                            0.9205239,
                        ],
                    ]
                ]
            ),
            "error_message": [
                "Expected shape does not match",
                "Expected output does not match.",
            ],
        },
    ]

    for test_case in test_cases:
        name = test_case.get("name")

        input_dict = test_case.get("input")
        expected = test_case.get("expected")
        output = target(**input_dict)

        try:
            assert output.shape == expected.shape
            successful_cases += 1
        except:
            print(test_case.get("error")[0])
            failed_cases.append(
                {
                    "name": test_case["name"],
                    "expected": test_case["expected"].shape,
                    "got": output.shape,
                }
            )

        try:
            assert jnp.isclose(output, expected).all()
            successful_cases += 1
        except:
            print(test_case.get("error")[1])
            failed_cases.append(
                {
                    "name": test_case["name"],
                    "expected": test_case["expected"],
                    "got": output,
                }
            )

    if len(failed_cases) == 0:
        print("\033[92m All tests passed")
    else:
        print("\033[92m", successful_cases, " Tests passed")
        print("\033[91m", len(failed_cases), " Tests failed")
Example #8
def test_compute_attention_heads_closure(target):
    n_heads = 2
    d_head = 3
    successful_cases = 0
    failed_cases = []

    q = jnp.array([[1, 0, 0], [0, 1, 0]])

    test_cases = [
        {
            "name": "test dummy tensors",
            "input": {
                "x": jnp.array(
                    [jnp.concatenate([q, q], axis=-1), jnp.concatenate([q, q], axis=-1)]
                )
            },
            "expected": jnp.array(
                [
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                ]
            ),
            "error_message": [
                "Expected shape does not match",
                "Expected output does not match.",
            ],
        },
        {
            "name": "test dummy tensors",
            "input": {
                "x": jnp.array(
                    [
                        jnp.concatenate([q, q], axis=-1),
                        jnp.concatenate([q, q], axis=-1),
                        jnp.concatenate([q, q], axis=-1),
                    ]
                )
            },
            "expected": jnp.array(
                [
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                    [[1, 0, 0], [0, 1, 0]],
                ]
            ),
            "error_message": [
                "Expected shape does not match",
                "Expected output does not match.",
            ],
        },
    ]

    for test_case in test_cases:
        name = test_case.get("name")

        input_dict = test_case.get("input")
        expected = test_case.get("expected")
        output = target(n_heads=n_heads, d_head=d_head)(**input_dict)

        try:
            assert output.shape == expected.shape
            successful_cases += 1
        except:
            print(test_case.get("error")[0])
            failed_cases.append(
                {
                    "name": test_case["name"],
                    "expected": test_case["expected"].shape,
                    "got": output.shape,
                }
            )

        try:
            assert jnp.isclose(output, expected).all()
            successful_cases += 1
        except:
            print(test_case.get("error")[1])
            failed_cases.append(
                {
                    "name": test_case["name"],
                    "expected": test_case["expected"],
                    "got": output,
                }
            )

    if len(failed_cases) == 0:
        print("\033[92m All tests passed")
    else:
        print("\033[92m", successful_cases, " Tests passed")
        print("\033[91m", len(failed_cases), " Tests failed")
Example #9
def _unshard_fn(x):
    y = jax.lax.all_gather(x, 'batch', axis_index_groups=groups)
    split_y = jnp.split(y, n_shards, axis=0)
    split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
    axis = _axis_to_shard_heuristic(split_y[0].shape)
    return jnp.concatenate(split_y, axis=axis)
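
A sketch of the same reshaping outside pmap, assuming n_shards arrays stacked along a new leading axis (a stand-in for all_gather's output):

import jax.numpy as jnp

n_shards = 4
y = jnp.arange(4 * 2 * 3).reshape(n_shards, 2, 3)   # stand-in for gathered value
split_y = jnp.split(y, n_shards, axis=0)
split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
print(jnp.concatenate(split_y, axis=0).shape)       # (8, 3)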
Example #10
def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return jnp.concatenate((-x2, x1), axis=x1.ndim - 1)
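
rotate_half is the pairing step of rotary position embeddings. A minimal sketch of one common way it is applied (GPT-NeoX-style half-split layout; the exact wiring in this codebase may differ):

import jax.numpy as jnp

def apply_rotary(x):
    # x: (seqlen, d_feature), d_feature even.
    seqlen, d = x.shape
    freqs = 1.0 / (10000 ** (jnp.arange(0.0, d, 2.0) / d))
    angles = jnp.arange(seqlen)[:, None] * freqs[None, :]
    sin = jnp.concatenate([jnp.sin(angles), jnp.sin(angles)], axis=-1)
    cos = jnp.concatenate([jnp.cos(angles), jnp.cos(angles)], axis=-1)
    return x * cos + rotate_half(x) * sin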
Example #11
def Sinusoidal_Embeddings(positions):
  inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
  sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
  pos_emb = jnp.concatenate(
      [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
  return pos_emb
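
A usage sketch, assuming d_feature is defined in the enclosing scope (it is closed over by the function above):

import jax.numpy as jnp

d_feature = 8
positions = jnp.arange(4.0)
print(Sinusoidal_Embeddings(positions).shape)   # (4, 8): [sin | cos] halves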
Example #12
def pad_borders(v):
  total_len = v.shape[2]
  pre, mid, post = split_along_l(v, chunk_offset,
                                 total_len - last_chunk_len, total_len)
  pre, post = map(pad_to_chunk_len, [pre, post])
  return jnp.concatenate([pre, mid, post], axis=2)
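
The helpers above (split_along_l, pad_to_chunk_len) and the values chunk_offset and last_chunk_len are closed over from the enclosing scope. Purely for illustration, plausible toy stand-ins (hypothetical, not the codebase's real definitions) that make pad_borders runnable on its length axis (axis=2):

import jax.numpy as jnp

def split_along_l(v, lo, hi, total_len):
    # Hypothetical: split the length axis at boundaries lo and hi.
    return v[:, :, :lo], v[:, :, lo:hi], v[:, :, hi:total_len]

def pad_to_chunk_len(v, chunk_len=4):
    # Hypothetical: right-pad the length axis up to chunk_len
    # (assumes v.shape[2] <= chunk_len).
    pad = [(0, 0)] * v.ndim
    pad[2] = (0, chunk_len - v.shape[2])
    return jnp.pad(v, pad)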