Example #1
def _ibp_dot_general(lhs: PrimitiveInput, rhs: PrimitiveInput,
                     **kwargs) -> PrimitiveInput:
    """Propagation of IBP bounds through a general dot product.

    We don't know if the bound is on the left or right hand side, but we expect
    that one hand is a bound and the other is a constant/parameter.

    Args:
      lhs: First input to the dot primitive.
      rhs: Second input to the dot primitive.
      **kwargs: Dict with the parameters of the general dot product.
    Returns:
      out_bounds: IntervalBound on the output of the dot product.
    """
    if (isinstance(lhs, bound_propagation.Bound) != isinstance(
            rhs, bound_propagation.Bound)):
        lhses = _decompose_affine_argument(lhs)
        rhses = _decompose_affine_argument(rhs)

        forward_mean = lax.dot_general(lhses[1], rhses[1], **kwargs)
        forward_range = lax.dot_general(lhses[0], rhses[0], **kwargs)

        out_lb = forward_mean - forward_range
        out_ub = forward_mean + forward_range

        return IntervalBound(out_lb, out_ub)

    elif ((not isinstance(lhs, bound_propagation.Bound))
          and (not isinstance(rhs, bound_propagation.Bound))):
        # Both are arrays, so can simply go through
        return lax.dot_general(lhs, rhs, **kwargs)
    else:
        raise ValueError('BoundPropagation through general dot product '
                         'is not supported when both inputs are bounds.')
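A minimal, standalone sketch of the midpoint/radius decomposition used above (this is not jax_verify's API; interval_matmul and its shapes are illustrative): for x in [lb, ub] and a constant kernel w, mid @ w plus/minus rad @ |w| bounds x @ w elementwise.

import jax.numpy as jnp
from jax import lax

def interval_matmul(lb, ub, w):
    # Decompose the bound into midpoint and radius, mirroring the
    # _decompose_affine_argument idea above.
    mid, rad = (lb + ub) / 2.0, (ub - lb) / 2.0
    dims = (((lb.ndim - 1,), (0,)), ((), ()))  # contract last dim of x with first dim of w
    forward_mean = lax.dot_general(mid, w, dims)
    forward_range = lax.dot_general(rad, jnp.abs(w), dims)  # radius propagates through |w|
    return forward_mean - forward_range, forward_mean + forward_range

lb, ub = jnp.array([[-1.0, 0.0]]), jnp.array([[1.0, 2.0]])
w = jnp.array([[1.0, -2.0], [3.0, 0.5]])
out_lb, out_ub = interval_matmul(lb, ub, w)  # elementwise bounds on x @ w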
Example #2
  def testDotGeneral(self):
    R = onp.random.RandomState(0).randn

    x = R(10, 3, 4, 5)
    y = R(10, 3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun)(x, y)
    expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
    self.assertAllClose(ans, expected, check_dtypes=True)

    x = R(3, 4, 10, 5)
    y = R(3, 10, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(2, 1))(x, y)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    expected = onp.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    x = R(3, 4, 5, 10)
    y = R(3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(3, None))(x, y)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    expected = onp.stack([fun(x[..., i], y) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    x = R(3, 4, 5)
    y = R(3, 5, 10, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(None, 2))(x, y)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)
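For reference, the equivalence the first case of this test relies on, shown as a standalone check (shapes here are arbitrary): vmapping a plain matmul over a leading axis matches dot_general with that axis declared as a batch dimension.

import numpy as onp
import jax.numpy as jnp
from jax import lax, vmap

x = onp.random.RandomState(1).randn(8, 3, 4)
y = onp.random.RandomState(2).randn(8, 4, 5)
batched = vmap(jnp.matmul)(x, y)                              # maps over axis 0
direct = lax.dot_general(x, y, (((2,), (1,)), ((0,), (0,))))  # axis 0 as a batch dim
assert jnp.allclose(batched, direct, atol=1e-5)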
Example #3
    def __call__(self, x: Array) -> Array:
        """Applies the symmetrized linear transformation to the inputs along the last dimension.

        Args:
          x: The nd-array to be transformed.

        Returns:
          The transformed input.
        """
        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)
        # infer in_features and ensure input dimensions (batch, in_features, n_sites)

        # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
        if x.ndim < 3:
            old_shape = x.shape
            if x.ndim == 1:
                x = jnp.expand_dims(x, (0, 1))
            elif x.ndim == 2:
                x = jnp.expand_dims(x, 1)
            symm_input_warning(old_shape, x.shape, "DenseSymm")

        in_features = x.shape[1]

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_sites),
            self.dtype,
        )

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        # Converts the convolutional kernel of shape (self.features, in_features, n_sites)
        # to a full dense kernel of shape (self.features, in_features, n_symm, n_sites).
        # result[out, in, g, r] == kernel[out, in, g^{-1}r]
        kernel = jnp.take(kernel, jnp.asarray(self.symmetries), 2)
        kernel = jnp.asarray(kernel, dtype)

        # x is      (batches,       in_features,         n_sites)
        # kernel is (self.features, in_features, n_symm, n_sites)
        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 2, x.ndim - 1), (1, 3)), ((), ())),
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)

            # Convert symmetry-reduced bias of shape (features,) to the full bias of
            # shape (..., features, 1).
            bias = jnp.expand_dims(bias, 1)
            bias = jnp.asarray(bias, dtype)

            x += bias

        return x
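A shape-only check of the contraction above, with hypothetical sizes (the names below are placeholders, not the module's attributes): x is (batch, in_features, n_sites) and the expanded kernel is (features, in_features, n_symm, n_sites), so contracting dims (1, 2) of x against (1, 3) of the kernel leaves (batch, features, n_symm).

import jax.numpy as jnp
from jax import lax

batch, in_features, n_sites, features, n_symm = 2, 3, 6, 4, 6
x = jnp.ones((batch, in_features, n_sites))
kernel = jnp.ones((features, in_features, n_symm, n_sites))
out = lax.dot_general(x, kernel, (((1, 2), (1, 3)), ((), ())))
assert out.shape == (batch, features, n_symm)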
Example #4
    def apply(self,
              inputs,
              features,
              bias=True,
              dtype=jnp.float32,
              precision=None,
              kernel_init=default_kernel_init,
              bias_init=initializers.zeros):
        """Applies a linear transformation to the inputs along the last dimension.

        Args:
          inputs: The nd-array to be transformed.
          features: the number of output features.
          bias: whether to add a bias to the output (default: True).
          dtype: the dtype of the computation (default: float32).
          precision: numerical precision of the computation; see `jax.lax.Precision`
            for details.
          kernel_init: initializer function for the weight matrix.
          bias_init: initializer function for the bias.
        Returns:
          The transformed input.
        """
        inputs = jnp.asarray(inputs, dtype)
        kernel = self.param('kernel', (inputs.shape[-1], features),
                            kernel_init)
        kernel = jnp.asarray(kernel, dtype)
        y = lax.dot_general(inputs,
                            kernel, (((inputs.ndim - 1, ), (0, )), ((), ())),
                            precision=precision)
        if bias:
            bias = self.param('bias', (features, ), bias_init)
            bias = jnp.asarray(bias, dtype)
            y = y + bias
        return y
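The dimension numbers used here, (((inputs.ndim - 1,), (0,)), ((), ())), contract the last axis of `inputs` with the first axis of `kernel`; for 2-D inputs this is exactly `inputs @ kernel`. A quick standalone check with toy values:

import jax.numpy as jnp
from jax import lax

inputs = jnp.arange(6.0).reshape(2, 3)
kernel = jnp.arange(12.0).reshape(3, 4)
y = lax.dot_general(inputs, kernel, (((1,), (0,)), ((), ())))
assert jnp.array_equal(y, inputs @ kernel)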
Example #5
 def __call__(self, inputs, kernel):
     inputs = jnp.asarray(inputs, self.dtype)
     kernel = jnp.asarray(kernel, self.dtype)
     y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())))
     bias = jnp.asarray(self.bias, self.dtype)
     y = y + bias
     return y
Example #6
    def test_vmap_after(self):
        batch = 4
        qy_size = 128
        db_size = 1024
        feature_dim = 32
        k = 10
        rng = jtu.rand_default(self.rng())
        qy = rng([qy_size, feature_dim, batch], np.float32)
        db = rng([db_size, feature_dim, batch], np.float32)
        recall = 0.95

        # Create ground truth
        gt_scores = lax.dot_general(qy, db, (([1], [1]), ([2], [2])))
        _, gt_args = lax.top_k(gt_scores, k)
        gt_args = lax.transpose(gt_args, [2, 0, 1])
        gt_args = lax.reshape(gt_args, [qy_size * batch, k])

        # test target
        def approx_max_k(qy, db):
            scores = qy @ db.transpose()
            return lax.approx_max_k(scores, k)

        _, ann_args = jax.vmap(approx_max_k, (2, 2))(qy, db)
        ann_args = lax.transpose(ann_args, [2, 0, 1])
        ann_args = lax.reshape(ann_args, [qy_size * batch, k])
        ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
        self.assertGreater(ann_recall, recall)
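A shape note on the ground-truth scores above, reproduced with the same layout but arbitrary sizes: because dot_general puts batch dimensions first in the output, contracting dims ([1], [1]) with batch dims ([2], [2]) yields (batch, qy_size, db_size).

import jax.numpy as jnp
from jax import lax

qy = jnp.ones((128, 32, 4))   # (qy_size, feature_dim, batch)
db = jnp.ones((1024, 32, 4))  # (db_size, feature_dim, batch)
scores = lax.dot_general(qy, db, (([1], [1]), ([2], [2])))
assert scores.shape == (4, 128, 1024)  # (batch, qy_size, db_size)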
Example #7
def dot_general_dependency_rule(outstart, outcount, lhs, rhs,
                                dimension_numbers, precision):
    if not is_ones(outcount):
        raise NotImplementedError
    outshape = outcount.shape
    outslices = list(zip(outstart, outshape))
    (lhs_contracting, rhs_contracting), (lhs_batch,
                                         rhs_batch) = dimension_numbers
    lhs_other_out_dims = list(
        range(len(lhs_batch),
              len(lhs.shape) - len(lhs_contracting)))
    rhs_other_out_dims = list(
        range(len(rhs_batch) + len(lhs_other_out_dims), len(outshape)))
    lhs_outstart, lhs_outshape = unzip2(
        [outslices[d] for d in list(lhs_batch) + lhs_other_out_dims])
    (lhs_box, ), (lhs_count, ), _ = reduce_dependency_rule(None)(
        lhs_outstart, Ones(lhs_outshape), lhs, axes=lhs_contracting)
    rhs_outstart, rhs_outshape = unzip2(
        [outslices[d] for d in list(rhs_batch) + rhs_other_out_dims])
    (rhs_box, ), (rhs_count, ), _ = reduce_dependency_rule(None)(
        rhs_outstart, Ones(rhs_outshape), rhs, axes=rhs_contracting)
    incounts = [
        materialize(lhs_count) * prod(np.take(outshape, rhs_other_out_dims))
        if isinstance(lhs, LazyArray) else None,
        materialize(rhs_count) * prod(np.take(outshape, lhs_other_out_dims))
        if isinstance(rhs, LazyArray) else None
    ]
    return ([lhs_box, rhs_box], incounts, lambda *inslices: lax.dot_general(
        *inslices, dimension_numbers, precision))
Example #8
def generalized_kernel_feature_creator(data, projection_matrix, batch_dims_t,
                                       precision, kernel_fn, kernel_epsilon,
                                       normalize_data):
  """Constructs kernel features for fast generalized attention.
  Args:
    data: input for which features are computed
    projection_matrix: matrix used to compute features
    batch_dims_t: tuple of batch dimensions
    precision: precision parameter
    kernel_fn: kernel function used
    kernel_epsilon: additive positive term added to every feature for numerical
      stability
    normalize_data: predicate indicating whether data should be normalized.
  Returns:
    Random features for fast generalized attention.
  """
  if normalize_data:
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    data_normalizer = 1.0
  if projection_matrix is None:
    return kernel_fn(data_normalizer * data) + kernel_epsilon
  else:
    data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
    data_dash = lax.dot_general(
        data_normalizer * data,
        data_thick_random_matrix,
        (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
         (batch_dims_t, batch_dims_t)),
        precision=precision)
  data_prime = kernel_fn(data_dash) + kernel_epsilon
  return data_prime
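A shape sketch of the batched projection above, under an assumed layout (the real module's layout may differ): data of shape (batch, heads, length, dim), projection broadcast to (batch, heads, num_features, dim), and batch_dims_t = (0, 1).

import jax.numpy as jnp
from jax import lax

data = jnp.ones((2, 4, 16, 8))   # (batch, heads, length, dim) -- assumed layout
proj = jnp.ones((2, 4, 32, 8))   # projection_matrix broadcast per batch/head
batch_dims_t = (0, 1)
feats = lax.dot_general(
    data, proj,
    (((data.ndim - 1,), (proj.ndim - 1,)), (batch_dims_t, batch_dims_t)))
assert feats.shape == (2, 4, 16, 32)  # (batch, heads, length, num_features)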
Example #9
  def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec,
                            axis_resources, mesh_data):
    rng = jtu.rand_default(self.rng())
    lhs = rng(lhs_shape, np.float32)
    rhs = rng(rhs_shape, np.float32)

    expected_out, ref_vjp = jax.vjp(
        lambda x, y: lax.dot_general(x, y, pdot_spec.dot_general_dim_nums),
        lhs, rhs)
    out_bar = rng(expected_out.shape, np.float32)
    expected_lhs, expected_rhs = ref_vjp(out_bar)

    def pdot_fun(x, y, out_bar):
      pdot = partial(jax.lax.pdot,
                     axis_name=pdot_spec.contract_names,
                     pos_batch=pdot_spec.pos_batch_after_mapping,
                     pos_contract=pdot_spec.pos_contract_after_mapping)
      _, pdot_vjp = jax.vjp(pdot, x, y)
      return pdot_vjp(out_bar)

    fun = xmap(pdot_fun,
               in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes,
                        [*pdot_spec.batch_names, ...]],
               out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
               axis_resources=axis_resources)

    with with_mesh(mesh_data):
      lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)

    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
    self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False,
                        atol=tol, rtol=tol)
    self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False,
                        atol=tol, rtol=tol)
Example #10
  def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources,
                         mesh_data):
    rng = jtu.rand_default(self.rng())
    lhs = rng(lhs_shape, np.float32)
    rhs = rng(rhs_shape, np.float32)

    def pdot_fun(x, y):
      # print(f'pdot(x:{x.aval.str_short()}, y:{y.aval.str_short()},\n'
      #       f'     axis_name={contract_names},\n'
      #       f'     pos_contract={spec.pos_contract_after_mapping}\n'
      #       f'     pos_batch={spec.pos_batch_after_mapping})')
      return jax.lax.pdot(x, y, axis_name=pdot_spec.contract_names,
                          pos_batch=pdot_spec.pos_batch_after_mapping,
                          pos_contract=pdot_spec.pos_contract_after_mapping)

    fun = xmap(pdot_fun, in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes],
               out_axes=[*pdot_spec.batch_names, ...],
               axis_resources=axis_resources)

    with with_mesh(mesh_data):
      result = fun(lhs, rhs)

    expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
    self.assertAllClose(result, expected, check_dtypes=False,
                        atol=tol, rtol=tol)
Example #11
def nonnegative_softmax_kernel_feature_creator(data,
                                               projection_matrix,
                                               attention_dims_t,
                                               batch_dims_t,
                                               precision,
                                               is_query,
                                               normalize_data=True,
                                               eps=0.0001):
  """Constructs nonnegative kernel features for fast softmax attention.


  Args:
    data: input for which features are computed
    projection_matrix: random matrix used to compute features
    attention_dims_t: tuple of attention dimensions
    batch_dims_t: tuple of batch dimensions
    precision: precision parameter
    is_query: predicate indicating whether input data corresponds to queries or
      keys
    normalize_data: predicate indicating whether data should be normalized.
    eps: numerical stabilizer.

  Returns:
    Random features for fast softmax attention.
  """

  if normalize_data:
    # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
    # w_norm = w * data_normalizer for w in {q,k}.
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    data_normalizer = 1.0
  ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
  data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
  data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

  data_dash = lax.dot_general(
      data_normalizer * data,
      data_thick_random_matrix,
      (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
       (batch_dims_t, batch_dims_t)),
      precision=precision)

  diag_data = jnp.square(data)
  diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
  diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
  diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)

  last_dims_t = (len(data_dash.shape) - 1,)
  if is_query:
    data_dash = ratio * (
        jnp.exp(data_dash - diag_data -
                jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps)
  else:
    data_dash = ratio * (
        jnp.exp(data_dash - diag_data - jnp.max(
            data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) +
        eps)

  return data_dash
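A small Monte-Carlo check of the identity these features are built on (a standalone sketch, not part of the module above): for w ~ N(0, I), E_w[exp(w·q - |q|^2/2) * exp(w·k - |k|^2/2)] = exp(q·k), which is why the exp(data_dash - diag_data) construction approximates the softmax kernel.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
q = jnp.array([0.1, -0.2, 0.3, 0.05])
k = jnp.array([0.2, 0.1, -0.1, 0.4])
w = jax.random.normal(key, (200_000, 4))                   # random projections
phi = lambda x: jnp.exp(w @ x - 0.5 * jnp.dot(x, x))       # positive random features
print(jnp.mean(phi(q) * phi(k)), jnp.exp(jnp.dot(q, k)))   # the two values closely agree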
Example #12
    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along the last dimension.
        Args:
          inputs: The nd-array to be transformed.
        Returns:
          The transformed input.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)
        kernel = self.param(
            "kernel", self.kernel_init, (inputs.shape[-1], self.features), self.dtype
        )
        kernel = jnp.asarray(kernel, dtype)
        y = lax.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias
        return y
Example #13
    def __call__(self, query, key, value, bias=None, dtype=jnp.float32):
        assert key.ndim == query.ndim
        assert key.ndim == value.ndim

        n = query.ndim
        attn_weights = lax.dot_general(query, key,
                                       (((n - 1, ), (n - 1, )), ((), ())))
        if bias is not None:
            attn_weights += bias
        attn_weights = self.attn_module()(attn_weights)
        attn_weights = attn_weights.astype(dtype)

        contract_dims = (tuple(range(n - 1, attn_weights.ndim)),
                         tuple(range(0, n - 1)))
        y = lax.dot_general(attn_weights, value, (contract_dims, ((), ())))
        return y
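Shape walk-through of the two contractions above, with query/key/value of an assumed (batch, length, depth) shape: the first dot_general contracts only the depth axes (no batch dims), so attn_weights spans both batch axes; the second then contracts the trailing (batch, kv_length) pair of attn_weights against the leading axes of value.

import jax.numpy as jnp
from jax import lax

q, k, v = jnp.ones((2, 5, 8)), jnp.ones((2, 7, 8)), jnp.ones((2, 7, 8))
n = q.ndim
w = lax.dot_general(q, k, (((n - 1,), (n - 1,)), ((), ())))
assert w.shape == (2, 5, 2, 7)                       # (batch, q_len, batch, kv_len)
contract_dims = (tuple(range(n - 1, w.ndim)), tuple(range(0, n - 1)))
y = lax.dot_general(w, v, (contract_dims, ((), ())))
assert y.shape == (2, 5, 8)                          # (batch, q_len, depth)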
Example #14
 def __call__(self, x, kernel):
     y = lax.dot_general(
         x,
         kernel,
         (((x.ndim - 1, ), (0, )), ((), ())),
         precision=self.precision,
     )
     return y + self.bias
Example #15
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """
        in_features = x.shape[-2]

        x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:-1], *self.shape)

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (self.features,), self.param_dtype
            )
        else:
            bias = None

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_point * self.n_cells),
            self.param_dtype,
        )

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        x, kernel, bias = promote_dtype(x, kernel, bias, dtype=None)
        dtype = x.dtype

        # Convert the convolutional kernel of shape (features, in_features, n_symm)
        # to the expanded kernel of shape (features, in_features, n_point(in),
        # n_point(out), *shape) used in FFT-based group convolutions
        kernel = kernel[..., self.mapping]

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

        kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
            *kernel.shape[:4], self.n_cells
        )

        x = lax.dot_general(
            x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
        )
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:2], -1)

        if self.use_bias:
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x, dtype):
            return x
        else:
            return x.real
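A shape-only sketch of the Fourier-space contraction above, with hypothetical sizes: after the FFT and reshape, x is (batch, in_features, n_point, n_cells) and the kernel is (features, in_features, n_point_in, n_point_out, n_cells); batching over the n_cells axes while contracting (in_features, n_point) gives (n_cells, batch, features, n_point_out), which the transpose(1, 2, 3, 0) then reorders.

import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3, 4, 8))          # (batch, in_features, n_point, n_cells)
kernel = jnp.ones((5, 3, 4, 4, 8))  # (features, in_feat, n_pt_in, n_pt_out, n_cells)
out = lax.dot_general(x, kernel, (((1, 2), (1, 2)), ((3,), (4,))))
assert out.shape == (8, 2, 5, 4)    # (n_cells, batch, features, n_point_out)
assert out.transpose(1, 2, 3, 0).shape == (2, 5, 4, 8)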
Example #16
 def __call__(self, inputs):
   kernel = self.param('kernel', self.kernel_init,
                       (inputs.shape[-1], self.features))
   y = lax.dot_general(inputs, kernel,
                       (((inputs.ndim - 1,), (0,)), ((), ())),)
   if self.use_bias:
     bias = self.param('bias', self.bias_init, (self.features,))
     y = y + bias
   return y
Example #17
def _cov_full_batch_diag_spatial(x1: np.ndarray, x2: np.ndarray,
                                 batch_axis: int,
                                 channel_axis: int) -> np.ndarray:
    diag_axes = tuple(i for i in range(x1.ndim)
                      if i != batch_axis and i != channel_axis)
    ret = lax.dot_general(x1, x2, (((channel_axis, ), (channel_axis, )),
                                   (diag_axes, diag_axes)))
    ret = np.moveaxis(ret, (-2, -1), (0, 1))
    return ret
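A quick check of the pattern above with hypothetical sizes (batch_axis=0, channel_axis=3): the spatial axes become batch dimensions of the contraction, so only their diagonal is kept rather than the full pairwise product.

import jax.numpy as jnp
from jax import lax

x1 = jnp.ones((3, 5, 6, 7))  # (N1, H, W, C)
x2 = jnp.ones((4, 5, 6, 7))  # (N2, H, W, C)
ret = lax.dot_general(x1, x2, (((3,), (3,)), ((1, 2), (1, 2))))
ret = jnp.moveaxis(ret, (-2, -1), (0, 1))
assert ret.shape == (3, 4, 5, 6)  # (N1, N2, H, W)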
Example #18
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:-1], *self.shape)

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (
                self.out_features,
                self.in_features,
                self.n_point * self.n_cells,
            ),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.mask, (0, 1))

        kernel = self.make_kernel(kernel)

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

        kernel = jnp.fft.fftn(kernel,
                              s=self.shape).reshape(*kernel.shape[:4],
                                                    self.n_cells)

        x = lax.dot_general(x,
                            kernel, (((1, 2), (1, 2)), ((3, ), (4, ))),
                            precision=self.precision)
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:2], -1)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.out_features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x, dtype):
            return x
        else:
            return x.real
Example #19
 def parallel_topk(qy, db, db_offset):
     scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
     ann_vals, ann_args = lax.approx_min_k(
         scores,
         k=k,
         reduction_dimension=1,
         recall_target=recall,
         reduction_input_size_override=db_size,
         aggregate_to_topk=False)
     return (ann_vals, ann_args + db_offset)
Example #20
 def __call__(self, inputs):
     inputs = jnp.asarray(inputs, self.dtype)
     kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, inputs.shape[-1]))
     kernel = jnp.asarray(kernel.transpose(), self.dtype)
     y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision)
     if self.use_bias:
         bias = self.param("bias", jax.nn.initializers.zeros, (self.features,))
         bias = jnp.asarray(bias, self.dtype)
         y = y + bias
     return y
Example #21
 def __call__(self, inputs):
     kernel = self.param(
         "kernel", self.kernel_init, (self.n_tasks, inputs.shape[-1], self.features)
     )
     y = lax.dot_general(
         inputs, kernel, dimension_numbers=(((2,), (1,)), ((0,), (0,)))
     )
     bias = self.param("bias", self.bias_init, (self.n_tasks, 1, self.features))
     y = y + bias
     return y
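Shape sketch with hypothetical sizes: inputs of shape (n_tasks, batch, in_dim) batched against a per-task kernel of shape (n_tasks, in_dim, features) give (n_tasks, batch, features), which is why the bias is created as (n_tasks, 1, features).

import jax.numpy as jnp
from jax import lax

inputs = jnp.ones((3, 8, 16))   # (n_tasks, batch, in_dim)
kernel = jnp.ones((3, 16, 4))   # (n_tasks, in_dim, features)
y = lax.dot_general(inputs, kernel, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
assert y.shape == (3, 8, 4)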
Example #22
  def __call__(self, inputs: Array) -> Array:
    """Applies a linear transformation to the inputs along multiple dimensions.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    inputs = jnp.asarray(inputs, self.dtype)

    ndim = inputs.ndim
    n_batch_dims = len(self.batch_dims)
    axis = _normalize_axes(self.axis, ndim)
    batch_dims = _normalize_axes(self.batch_dims, ndim)
    n_axis, n_features = len(axis), len(self.features)

    def kernel_init_wrap(rng, shape, dtype=jnp.float32):
      size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
      flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]),
                    np.prod(shape[-n_features:]),)
      kernel = jnp.concatenate([self.kernel_init(rng, flat_shape, dtype)
                                for _ in range(size_batch_dims)], axis=0)
      return jnp.reshape(kernel, shape)

    batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
    kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + self.features
    kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape)
    kernel = jnp.asarray(kernel, self.dtype)

    batch_ind = tuple(range(n_batch_dims))
    contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
    out = lax.dot_general(inputs,
                          kernel,
                          ((axis, contract_ind), (batch_dims, batch_ind)),
                          precision=self.precision)
    if self.use_bias:
      def bias_init_wrap(rng, shape, dtype=jnp.float32):
        size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
        flat_shape = (np.prod(shape[-n_features:]),)
        bias = jnp.concatenate([self.bias_init(rng, flat_shape, dtype)
                                for _ in range(size_batch_dims)], axis=0)
        return jnp.reshape(bias, shape)

      bias = self.param('bias', bias_init_wrap, batch_shape + self.features)

      # Reshape bias for broadcast.
      expand_dims = sorted(
          set(range(inputs.ndim)) - set(axis) - set(batch_dims))
      for ax in expand_dims:
        bias = jnp.expand_dims(bias, ax)
      bias = jnp.asarray(bias, self.dtype)
      out = out + bias
    return out
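The general contraction above, reduced to a standalone call with hypothetical sizes: batch over axis 0 of both operands, contract input axes (2, 3) against the matching leading non-batch kernel axes, and keep the kernel's feature axes.

import jax.numpy as jnp
from jax import lax

inputs = jnp.ones((5, 7, 3, 4))     # (batch, keep, ax1, ax2)
kernel = jnp.ones((5, 3, 4, 6, 2))  # (batch, ax1, ax2, feat1, feat2)
out = lax.dot_general(inputs, kernel, (((2, 3), (1, 2)), ((0,), (0,))))
assert out.shape == (5, 7, 6, 2)    # (batch, keep, feat1, feat2)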
Example #23
 def dot_general_int(ops):
     lhs_, rhs_ = ops
     input_dtype = lhs_.dtype
     lhs_int = lhs_.astype(jnp.int8)
     rhs_int = rhs_.astype(jnp.int8)
     return lax.dot_general(
         lhs_int,
         rhs_int,
         dimension_numbers=dimension_numbers,
         precision=dot_precision,
         preferred_element_type=jnp.int32).astype(input_dtype)
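A standalone version of the int8 path above (hedged: `dimension_numbers` and `dot_precision` are closed-over variables in the snippet; here a plain matmul is hard-coded). Casting to int8 and requesting int32 accumulation with `preferred_element_type` keeps the products exact for small integer values.

import jax.numpy as jnp
from jax import lax

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([[5.0, 6.0], [7.0, 8.0]])
out = lax.dot_general(a.astype(jnp.int8), b.astype(jnp.int8),
                      dimension_numbers=(((1,), (0,)), ((), ())),
                      preferred_element_type=jnp.int32)
assert out.dtype == jnp.int32
assert jnp.array_equal(out, (a @ b).astype(jnp.int32))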
Example #24
 def __call__(self, x, kernel):
     x = jnp.asarray(x, self.dtype)
     kernel = jnp.asarray(kernel, self.dtype)
     y = lax.dot_general(
         x,
         kernel,
         (((x.ndim - 1, ), (0, )), ((), ())),
         precision=self.precision,
     )
     bias = jnp.asarray(self.bias, self.dtype)
     return y + bias
Example #25
 def low_rank_projection(inputs, kernel, precision):
     """low rank projection."""
     input_dim = inputs.shape[1]
     # this kernel/parameter relies on sequence length
     kernel = kernel[:input_dim, :]
     inputs = inputs.transpose((0, 3, 2, 1))
     y = lax.dot_general(inputs,
                         kernel,
                         (((inputs.ndim - 1, ), (0, )), ((), ())),
                         precision=precision)
     y = y.transpose((0, 3, 2, 1))
     return y
Example #26
def sincos_softmax_kernel_feature_creator(data,
                                          projection_matrix,
                                          attention_dims_t,
                                          batch_dims_t,
                                          precision,
                                          normalize_data=True):
    """
    Constructs kernel sin-cos features for fast softmax attention

    Args:
      data: input for which features are computed
      projection_matrix: random matrix used to compute features
      attention_dims_t: tuple of attention dimensions
      batch_dims_t: tuple of batch dimensions
      precision: precision parameter
      normalize_data: predicate indicating whether data should be normalized

    Returns:
      Random features for fast softmax attention.
    """
    if normalize_data:
        # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
        # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
        data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0
    ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
    data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

    data_dash = lax.dot_general(
        data_normalizer * data,
        data_thick_random_matrix,
        (((data.ndim - 1, ), (data_thick_random_matrix.ndim - 1, )),
         (batch_dims_t, batch_dims_t)),
        precision=precision,
    )
    data_dash_cos = ratio * jnp.cos(data_dash)
    data_dash_sin = ratio * jnp.sin(data_dash)
    data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1)

    # Constructing D_data and data^{'}
    diag_data = jnp.square(data)
    diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
    # Additional renormalization for numerical stability
    data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True)
    diag_data -= data_renormalizer
    diag_data = jnp.exp(diag_data)
    data_prime = data_dash * diag_data
    return data_prime
Example #27
 def __call__(self, inputs):
     inputs = jnp.asarray(inputs, self.dtype)
     kernel = self.param('kernel', self.kernel_init,
                         (inputs.shape[-1], self.features))
     kernel = jnp.asarray(kernel, self.dtype)
     y = lax.dot_general(inputs,
                         kernel, (((inputs.ndim - 1, ), (0, )), ((), ())),
                         precision=self.precision)
     if self.use_bias:
         bias = self.param('bias', self.bias_init, (self.features, ))
         bias = jnp.asarray(bias, self.dtype)
         y = y + bias
     return y
Example #28
def _dot_product_attention(scope: Scope,
                           query: Array,
                           key: Array,
                           value: Array,
                           bias: Optional[Array] = None,
                           attn_fn: Callable = softmax_attn,
                           dtype=jnp.float32):
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim

    n = query.ndim
    attn_weights = lax.dot_general(query, key,
                                   (((n - 1, ), (n - 1, )), ((), ())))
    if bias is not None:
        attn_weights += bias
    attn_weights = attn_fn(scope, attn_weights)
    attn_weights = attn_weights.astype(dtype)

    contract_dims = (tuple(range(n - 1,
                                 attn_weights.ndim)), tuple(range(0, n - 1)))
    y = lax.dot_general(attn_weights, value, (contract_dims, ((), ())))
    return y
Example #29
def dot_general(lhs: np.ndarray,
                rhs: np.ndarray,
                contracting_dims: Axes,
                batch_dims: Axes,
                precision=None) -> np.ndarray:
    """`jax.lax.dot_general` with preserved dims order and shared lhs / rhs dims.

    Precisely, returns `jax.lax.dot_general(lhs, rhs, dimension_numbers)` where
    `dimension_numbers == ((contracting_dims, contracting_dims),
                           (batch_dims, batch_dims))`,
    but allows arbitrary dimension order and preserves it in the output. See XLA's
    `DotGeneral <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`.

    Args:
      lhs: np.ndarray.
      rhs: np.ndarray.
      contracting_dims: contracting dimensions.
      batch_dims: batch dimensions.
      precision: Optional. Either `None`, which means the default precision for
        the backend, or a `Precision` enum value.

    Returns:
      Dot product result with preserved dimension order.
    """
    contracting_dims = canonicalize_axis(contracting_dims, lhs)
    batch_dims = canonicalize_axis(batch_dims, lhs)

    n_batch_dims = len(batch_dims)
    leading_batch_dims = range(n_batch_dims)

    lhs = np.moveaxis(lhs, batch_dims, leading_batch_dims)
    if rhs is None:
        rhs = lhs
    else:
        rhs = np.moveaxis(rhs, batch_dims, leading_batch_dims)

    shifted_contracting_dims = [
        i + sum(1 if i < b else 0 for b in batch_dims)
        for i in contracting_dims
    ]

    dimension_numbers = ((shifted_contracting_dims, shifted_contracting_dims),
                         (leading_batch_dims, leading_batch_dims))

    prod = lax.dot_general(lhs, rhs, dimension_numbers, precision)
    prod = zip_axes(prod, n_batch_dims)

    res_batch_dims = get_res_batch_dims(contracting_dims, batch_dims)
    prod = np.moveaxis(prod, leading_batch_dims, res_batch_dims)
    return prod
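The docstring's core claim, written as a direct lax.dot_general call with shared lhs/rhs dimension indices (hypothetical shapes; this sketch skips the wrapper's axis reordering): contract dim 2 and batch dim 0 on both operands.

import jax.numpy as jnp
from jax import lax

contracting_dims, batch_dims = (2,), (0,)
lhs = jnp.ones((4, 3, 5))
rhs = jnp.ones((4, 6, 5))
dimension_numbers = ((contracting_dims, contracting_dims), (batch_dims, batch_dims))
out = lax.dot_general(lhs, rhs, dimension_numbers)
assert out.shape == (4, 3, 6)  # (batch, lhs remainder, rhs remainder)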
Example #30
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements).
        Args:
          x: The nd-array to be transformed.
        Returns:
          The transformed input.
        """
        in_features = x.shape[-2]

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_symm),
            self.param_dtype,
        )

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (self.features,), self.param_dtype
            )
        else:
            bias = None

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        kernel, bias, x = promote_dtype(kernel, bias, x, dtype=None)

        # Converts the convolutional kernel of shape (features, in_features, n_symm)
        # to a full dense kernel of shape (features, in_features, n_symm, n_symm)
        # result[out, in, g, h] == kernel[out, in, g^{-1}h]
        # input dimensions are [in, g], output dimensions are [out, h]
        kernel = jnp.take(kernel, jnp.asarray(self.product_table), 2)

        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 2, x.ndim - 1), (1, 2)), ((), ())),
            precision=self.precision,
        )

        x = x.reshape(-1, self.features, self.n_symm)

        if self.use_bias:
            x += jnp.expand_dims(bias, 1)

        return x
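A shape-only check of the contraction above with hypothetical sizes: x of shape (batch, in_features, n_symm) against the product-table-expanded kernel (features, in_features, n_symm, n_symm), contracting (in_features, n_symm), leaves (batch, features, n_symm).

import jax.numpy as jnp
from jax import lax

batch, in_features, n_symm, features = 2, 3, 6, 4
x = jnp.ones((batch, in_features, n_symm))
kernel = jnp.ones((features, in_features, n_symm, n_symm))
out = lax.dot_general(x, kernel, (((1, 2), (1, 2)), ((), ())))
assert out.shape == (batch, features, n_symm)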