Ejemplo n.º 1
0
 def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                lhs_contracting, rhs_contracting, bdims):
     """Check batching of `lax.dot_general` with contracting dims only.

     The dimension numbers given to `dot_general` declare no batch
     dimensions of their own; `bdims` is forwarded unchanged to
     `self._CheckBatching` (presumably the per-argument batch axes --
     confirm against its definition).
     """
     # Contract lhs_contracting against rhs_contracting; empty batch dims.
     contracting = (lhs_contracting, rhs_contracting)
     no_batch = ([], [])
     dot_fn = partial(lax.dot_general,
                      dimension_numbers=(contracting, no_batch))
     arg_shapes = (lhs_shape, rhs_shape)
     arg_dtypes = (dtype, dtype)
     self._CheckBatching(dot_fn, 5, bdims, arg_shapes, arg_dtypes,
                         jtu.rand_small(self.rng()))
Ejemplo n.º 2
0
 def testVariadicReduce(self, shape, dtype, dims, bdims):
   """Check batching of a two-operand `lax.reduce`.

   The reducer adds the first operand's values and multiplies the second
   operand's values, so the result is a pair (hence multiple_results=True).
   """
   rng = jtu.rand_small(self.rng())

   def sum_and_product(acc, item):
     # Both arguments are (x, y) pairs: x components add, y components multiply.
     (acc_x, acc_y), (item_x, item_y) = acc, item
     return acc_x + item_x, acc_y * item_y

   # Initial values 0 (for the sum) and 1 (for the product), cast to dtype.
   init_val = tuple(np.asarray([0, 1], dtype=dtype))

   def fun(x, y):
     return lax.reduce((x, y), init_val, sum_and_product, dims)

   self._CheckBatching(fun, 5, bdims, (shape, shape), (dtype, dtype), rng,
                       multiple_results=True)
Ejemplo n.º 3
0
    def testReduceWindow(self, op, init_val, dtype, shape, dims, strides,
                         padding, base_dilation, window_dilation):
        """Check batching of `lax.reduce_window` for every bdims of `shape`."""
        rng = jtu.rand_small(self.rng())
        # Cast the initial value once so it carries the operand dtype.
        typed_init = np.asarray(init_val, dtype=dtype)
        # Window configuration, in the positional order lax.reduce_window takes.
        window_args = (dims, strides, padding, base_dilation, window_dilation)

        def fun(operand):
            return lax.reduce_window(operand, typed_init, op, *window_args)

        # all_bdims is defined elsewhere in the file; presumably it enumerates
        # the batch-dimension layouts for an argument of this shape.
        for bdims in all_bdims(shape):
            self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)
Ejemplo n.º 4
0
    def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                       dimension_numbers, bdims):
        """Check batching of `lax.dot_general` with contracting and batch dims.

        After the batching check, also asserts that the jaxpr traced from
        `dot` contains no transpose or broadcast equations.
        """
        rng = jtu.rand_small(self.rng())
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

        # Checks that no transposes or broadcasts appear in the jaxpr.
        # `eqn.primitive` is a Primitive object, so its `.name` must be
        # compared against the strings; comparing the object itself with
        # strings is never true and would make this check vacuous.
        # `broadcast_in_dim` is the primitive JAX emits for broadcasts.
        # NOTE(review): this traces the unbatched `dot`, not a vmapped one, so
        # it does not directly inspect what batching produced -- confirm intent.
        forbidden = ("transpose", "broadcast", "broadcast_in_dim")
        jaxpr = jax.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                    np.zeros(rhs_shape, dtype))
        for eqn in jtu.iter_eqns(jaxpr.jaxpr):
            self.assertNotIn(eqn.primitive.name, forbidden)
Ejemplo n.º 5
0
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    """Check batching of `lax._select_and_scatter_add` for one window config.

    The scattered `source` has the output shape of the matching
    `_select_and_gather_add` over `shape` (the cotangent shape); the
    `operand` has `shape` itself.
    """
    rng = jtu.rand_small(self.rng())

    # Explicit (low, high) padding pairs derived from the operand shape.
    pads = lax.padtype_to_pads(shape, dims, strides, padding)

    # Parameter order mirrors lax._select_and_scatter_add(source, operand, ...):
    # `source` receives cotangent_shape and `operand` receives shape, matching
    # the argument shapes passed to _CheckBatching below.
    def fun(source, operand):
      return lax._select_and_scatter_add(source, operand, lax.ge_p, dims,
                                         strides, pads)

    # Abstractly evaluate the forward select-and-gather to get the cotangent
    # shape; `ones` fills the last two arguments (presumably unit base and
    # window dilation -- confirm against lax).
    ones = (1,) * len(shape)
    cotangent_shape = jax.eval_shape(
      lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                           pads, ones, ones),
      np.ones(shape, dtype)).shape

    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                          (dtype, dtype), rng)
Ejemplo n.º 6
0
    def testSelectAndGatherAdd(self, dtype, padding):
        """Check batching of `_select_and_gather_add` over fixed window configs.

        `fun` reads `dims` and `strides` from the enclosing loop, so each
        `_CheckBatching` call uses the configuration of the current iteration.
        """
        rng = jtu.rand_small(self.rng())
        # Each config is a (shape, window dims, window strides) triple; the
        # rank-2 cases are visited first, then the rank-4 cases.
        rank2_configs = itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)])
        rank4_configs = itertools.product(
            [(3, 2, 4, 6)],
            [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)])
        all_configs = itertools.chain(rank2_configs, rank4_configs)

        def fun(operand, tangents):
            operand_shape = operand.shape
            pads = lax.padtype_to_pads(operand_shape, dims, strides, padding)
            # All-ones tuples for the last two arguments (presumably base and
            # window dilation -- confirm against lax_windowed_reductions).
            unit = (1, ) * len(operand_shape)
            return lax_windowed_reductions._select_and_gather_add(
                operand, tangents, lax.ge_p, dims, strides, pads, unit, unit)

        for shape, dims, strides in all_configs:
            for bdims in all_bdims(shape, shape):
                self._CheckBatching(fun, 3, bdims, (shape, shape),
                                    (dtype, dtype), rng)
Ejemplo n.º 7
0
 def testReduce(self, op, init_val, shape, dtype, dims, bdims):
     """Check batching of single-operand `lax.reduce` with reducer `op`."""
     # Cast the initial value so it carries the operand dtype.
     typed_init = np.asarray(init_val, dtype=dtype)

     def fun(operand):
         return lax.reduce(operand, typed_init, op, dims)

     rng = jtu.rand_small(self.rng())
     self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
Ejemplo n.º 8
0
 def testPad(self, shape, dtype, pads, bdims):
     """Check batching of `lax.pad` over the operand and a scalar pad value."""
     def fun(operand, pad_value):
         # `pad_value` is the scalar (shape ()) fill value; `pads` is the
         # padding configuration supplied by the test parameters.
         return lax.pad(operand, pad_value, pads)

     rng = jtu.rand_small(self.rng())
     arg_shapes = (shape, ())
     self._CheckBatching(fun, 5, bdims, arg_shapes, (dtype, dtype), rng)