Exemplo n.º 1
0
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    """Batching check for ``lax._select_and_scatter_add``.

    The cotangent shape is obtained by abstractly evaluating
    ``lax._select_and_gather_add`` on an all-ones array of ``shape``/``dtype``.
    ``self._CheckBatching`` is then invoked once for every batch-dimension
    combination that ``all_bdims`` yields for (cotangent shape, ``shape``).
    """
    rng = jtu.rand_small(self.rng())
    window_pads = lax.padtype_to_pads(shape, dims, strides, padding)
    ones_per_dim = (1,) * len(shape)

    def gather(x):
      return lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                        window_pads, ones_per_dim,
                                        ones_per_dim)

    probe = np.ones(shape, dtype)
    cotangent_shape = api.eval_shape(gather, probe).shape

    # The first argument is fed data of ``cotangent_shape`` and the second
    # data of ``shape`` (matching the shape tuple handed to _CheckBatching
    # below), so the parameters are named in that order.
    # NOTE(review): presumably _select_and_scatter_add takes
    # (source, operand, ...) -- confirm against the lax version in use.
    def fun(cotangents, operand):
      return lax._select_and_scatter_add(cotangents, operand, lax.ge_p, dims,
                                         strides, window_pads)

    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                          (dtype, dtype), rng)
Exemplo n.º 2
0
 def fun(operand, tangents):
     """Apply ``lax._select_and_gather_add`` with ``lax.ge_p`` as the selector.

     ``dims``, ``strides`` and ``padding`` are read from the enclosing scope.
     ``padding`` is turned into explicit pads for ``operand``'s shape via
     ``lax.padtype_to_pads``. The last two arguments are a pair of all-ones
     tuples with one entry per axis of ``operand``.
     """
     window_pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
     rank = len(operand.shape)
     unit = (1,) * rank
     return lax._select_and_gather_add(
         operand, tangents, lax.ge_p, dims, strides, window_pads, unit, unit)
Exemplo n.º 3
0
 def fun(operand, tangents):
     """Apply ``lax._select_and_gather_add`` with ``lax.ge_p`` as the selector.

     ``dims``, ``strides`` and ``padding`` come from the enclosing scope and
     are forwarded unchanged as the window arguments.
     """
     # NOTE(review): ``padding`` is passed through as-is and no further
     # arguments follow it. Other call forms pass explicit pads plus two
     # all-ones tuples. Confirm this matches the lax version targeted here.
     window_args = (dims, strides, padding)
     return lax._select_and_gather_add(operand, tangents, lax.ge_p,
                                       *window_args)