示例#1
0
def _nonzero_translation_rule(c, dims, avals, operands):
  """Lower `nonzero` along the last axis of its single operand.

  `dims` and `avals` are part of the translation-rule signature but are not
  used here.

  Returns:
    `[[out_dim], [out_val]]` where `out_dim` is the int32 count of nonzero
    entries along the last axis, and `out_val` is an int32 array (same
    dimensions as the operand) of last-axis positions, stably sorted so that
    positions holding nonzero values come first, in their original order.
  """
  (operand,), = operands
  operand_shape = c.get_shape(operand)
  dimensions = operand_shape.dimensions()
  minor_axis = len(dimensions) - 1
  int32 = np.dtype('int32')
  s32_etype = xc.dtype_to_etype(int32)

  # Indicator array: 1 where the operand is nonzero, 0 elsewhere.
  zero = xb.constant(c, np.zeros((), operand_shape.numpy_dtype()))
  is_nonzero = xops.Ne(operand, xops.Broadcast(zero, dimensions))
  indicators = xops.ConvertElementType(is_nonzero, s32_etype)

  # Nonzero count: sum the indicators over the minor axis.
  int32_scalar = core.ShapedArray((), int32)
  count_init = xb.constant(c, np.array(0, int32))
  out_dim = xops.Reduce(
      c, [indicators], [count_init],
      xla.primitive_subcomputation(lax.add_p, int32_scalar, int32_scalar),
      (minor_axis,))
  c.get_shape(out_dim)  # xla type checking

  # Sort (indicator, position) pairs with a "greater than" comparator on the
  # indicator key; the sort is stable, so nonzero positions keep their
  # relative order. The comparator takes 4 params (lhs/rhs for each operand).
  comparator_builder = xb.make_computation_builder("sort_gt_comparator")
  key_shape = xc.Shape.array_shape(s32_etype, ())
  lhs_key, rhs_key, _, _ = [
      xb.parameter(comparator_builder, n, key_shape) for n in range(4)]
  comparator = comparator_builder.build(xops.Gt(lhs_key, rhs_key))
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, dimensions)
  positions = xops.Iota(c, iota_shape, minor_axis)
  sorted_pair = xops.Sort(c, [indicators, positions],
                          is_stable=True, comparator=comparator)
  _, out_val = xla.xla_destructure(c, sorted_pair)
  c.get_shape(out_val)  # xla type checking
  return [[out_dim], [out_val]]
示例#2
0
文件: lax_linalg.py 项目: yotarok/jax
def _nan_like(c, operand):
    """Return an XLA array of NaNs with the same shape and dtype as `operand`.

    For complex dtypes, both the real and the imaginary parts are NaN.
    """
    operand_shape = c.GetShape(operand)
    element_dtype = operand_shape.element_type()
    is_complex = np.issubdtype(element_dtype, onp.complexfloating)
    # nan * (1 + 1j) yields nan + nanj, so the complex case is NaN in both parts.
    nan_value = onp.nan * (1. + 1j) if is_complex else onp.nan
    nan_scalar = xb.constant(c, onp.array(nan_value, dtype=element_dtype))
    return xops.Broadcast(nan_scalar, operand_shape.dimensions())
示例#3
0
def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
  """Lower `axis_index` by decoding this replica's position along `axis_name`.

  The replica id is divided by the product of the sizes of all axes after
  `axis_name` (times `axis_env.nreps // prod(axis_env.sizes)`), then taken
  modulo the size of `axis_name`. So axes later in `axis_env.names` vary
  faster. The arithmetic is done in uint32 and the result converted to int32.
  `platform` is unused.
  """
  pos = list(axis_env.names).index(axis_name)
  sizes = axis_env.sizes
  replicas_per_index = axis_env.nreps // prod(sizes)
  stride = replicas_per_index * prod(sizes[pos + 1:])
  divisor = xb.constant(c, np.array(stride, dtype=np.uint32))
  modulus = xb.constant(c, np.array(sizes[pos], dtype=np.uint32))
  replica_id = xops.ReplicaId(c)
  index_u32 = xops.Rem(xops.Div(replica_id, divisor), modulus)
  return xops.ConvertElementType(index_u32, xb.dtype_to_etype(np.int32))
示例#4
0
文件: lax_linalg.py 项目: yotarok/jax
def _triangular_solve_cpu_translation_rule(c, a, b, left_side, lower,
                                           transpose_a, conjugate_a,
                                           unit_diagonal):
    """CPU lowering of triangular_solve.

    Rank-2 operands whose dtype is in `_cpu_lapack_types` are dispatched to
    `lapack.jax_trsm`. Its second argument is a scalar constant 1 of the
    operand dtype (presumably trsm's alpha). Anything else, i.e. other dtypes
    or batched operands, falls back to XLA's TriangularSolve HLO.
    """
    a_shape = c.GetShape(a)
    dtype = a_shape.element_type().type

    # "Conjugate without transpose" is folded into `a` itself, so below
    # conjugate_a can only be true together with transpose_a.
    if conjugate_a and not transpose_a:
        a = xops.Conj(a)
        conjugate_a = False

    is_unbatched = len(a_shape.dimensions()) == 2
    if is_unbatched and onp.dtype(dtype) in _cpu_lapack_types:
        shim = xb.computation_builder_shim(c)
        one = xb.constant(c, onp.array(1, dtype=dtype))
        return lapack.jax_trsm(shim, one, a, b, left_side, lower,
                               transpose_a, conjugate_a, unit_diagonal)

    # Fall back to the HLO implementation for unsupported types or batching.
    # TODO: Consider swapping XLA for LAPACK in batched case
    options = xops.TriangularSolveOptions_Transpose
    if not transpose_a:
        transpose = options.NO_TRANSPOSE
    elif conjugate_a:
        transpose = options.ADJOINT
    else:
        transpose = options.TRANSPOSE
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                                transpose)
示例#5
0
def _reduce_sum_translation_rule(c, dims, avals, operands, *, axes):
  """Lower a sum over `axes` of a possibly dynamically-shaped operand.

  Args:
    c: XLA computation builder.
    dims: mapping from size variables to their XLA values. Presumably each
      value is a list whose first element is the scalar size (TODO confirm).
      Masking runs only when this mapping is non-empty.
    avals: single-element sequence holding the operand's abstract value. Its
      `shape` may contain `Var` entries for dynamic dimensions.
    operands: `[[x]]`, the single XLA operand.
    axes: axes to reduce over.

  Returns:
    `[[out]]`, the summed XLA value.

  Along every reduced axis whose size is a `Var`, positions at or beyond the
  dynamic size are replaced with zero before summing. This keeps padding from
  contributing to the result.
  """
  (x,), = operands
  shape = c.get_shape(x)
  dtype = shape.numpy_dtype()
  if dims:
    aval, = avals
    iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
    # One boolean mask per reduced axis with a dynamic (Var) size: True where
    # the position along that axis is below the dynamic size.
    masks = [xops.Lt(xops.Iota(c, iota_shape, i), dims[v][0])
             for i, v in enumerate(aval.shape)
             if isinstance(v, Var) and i in axes]
    # Use an explicit loop rather than map(). With the lazy builtin map the
    # shape checks would never run.
    for mask in masks:
      c.get_shape(mask)  # xla type checking
    # If every reduced axis is static there is nothing to mask. Also,
    # reduce() over an empty list would raise TypeError.
    if masks:
      x = xops.Select(reduce(xops.And, masks), x,
                      xops.Broadcast(xb.constant(c, np.zeros((), dtype)),
                                     shape.dimensions()))
  scalar = core.ShapedArray((), dtype)
  out = xops.Reduce(c, [x], [xb.constant(c, np.array(0, dtype))],
                    xla.primitive_subcomputation(lax.add_p, scalar, scalar),
                    axes)
  return [[out]]
示例#6
0
 def test_error_different_shapes(self):
     """Registering a second, different shape for one consumer ID must fail.

     First a float32[2, 3] outfeed is registered under consumer ID 123. After
     that, registering an int32[2, 3] (different element type) or a float32[2]
     (different dimensions) under the same ID is expected to raise a
     RuntimeError whose message matches the "does not match previous shape"
     pattern below.
     """
     comp = xla_bridge.make_computation_builder(self._testMethodName)
     token = hcb.xops.CreateToken(comp)
     consumer_id = 123
     mismatch_regex = ".*does not match previous shape element_type.*"

     def add_outfeed_of(array):
         hcb._outfeed_receiver.receiver.add_outfeed(
             comp, token, consumer_id, [xla_bridge.constant(comp, array)])

     add_outfeed_of(np.zeros((2, 3), dtype=np.float32))
     mismatched_arrays = (np.zeros((2, 3), dtype=np.int32),
                          np.zeros((2, ), dtype=np.float32))
     for mismatched in mismatched_arrays:
         with self.assertRaisesRegex(RuntimeError, mismatch_regex):
             add_outfeed_of(mismatched)
示例#7
0
    def test_error_bad_consumer_id(self):
        """Registering an outfeed under consumer ID 0 must be rejected.

        ID 0 is reserved, so the runtime is expected to raise a RuntimeError
        with a message stating that.
        """
        builder = xla_bridge.make_computation_builder(self._testMethodName)
        token = hcb.xops.CreateToken(builder)
        reserved_id = 0
        expected_message = "Consumer ID cannot be a reserved value: 0"
        with self.assertRaisesRegex(RuntimeError, expected_message):
            payload = [
                xla_bridge.constant(builder,
                                    np.zeros((2, 3), dtype=np.float32))
            ]
            hcb._outfeed_receiver.receiver.add_outfeed(
                builder, token, reserved_id, payload)
示例#8
0
def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
    """Lower threefry2x32 to `cuda_prng.threefry2x32`, broadcasting inputs.

    The four operands are broadcast to their common shape
    (`lax.broadcast_shapes`), with each operand's dimensions aligned to the
    trailing dimensions of that shape. If the common shape has a zero-sized
    dimension, the kernel is skipped and a tuple of two uint32 zero arrays of
    that shape is returned instead.
    """
    operands = (k1, k2, x1, x2)
    out_shape = lax.broadcast_shapes(
        *(c.get_shape(op).dimensions() for op in operands))
    out_rank = len(out_shape)

    if 0 in out_shape:
        zero = xla_bridge.constant(c, np.array(0, np.uint32))
        empty = xla_client.ops.Broadcast(zero, out_shape)
        return xla_client.ops.Tuple(c, [empty, empty])

    def _broadcast(operand):
        # Map the operand's dimensions onto the trailing dims of out_shape.
        operand_rank = c.get_shape(operand).rank()
        trailing_dims = tuple(range(out_rank - operand_rank, out_rank))
        return xla_client.ops.BroadcastInDim(operand, out_shape, trailing_dims)

    keys = (_broadcast(k1), _broadcast(k2))
    data = (_broadcast(x1), _broadcast(x2))
    return cuda_prng.threefry2x32(c, keys, data)
示例#9
0
 def _axis_index_translation_rule(c, nreps, sizes, axis_name):
   """Lower `axis_index` as `(replica_id // (nreps // prod(sizes))) % sizes[-1]`.

   The arithmetic is done in uint32 and the result converted to int32.
   `axis_name` is not used. Presumably the caller truncates `sizes` so that the
   target axis is last (TODO confirm against the binding site).
   """
   replicas_per_index = nreps // prod(sizes)
   divisor = xb.constant(c, np.array(replicas_per_index, dtype=np.uint32))
   modulus = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
   index_u32 = xops.Rem(xops.Div(xops.ReplicaId(c), divisor), modulus)
   return xops.ConvertElementType(index_u32, xb.dtype_to_etype(np.int32))
示例#10
0
 def read(v):
   """Return the list of XLA values for `v`.

   A `core.Literal` (exact type match) is materialized as a single XLA constant
   of `xla.canonicalize_dtype(v.val)`. Anything else is looked up in the
   enclosing `env`.
   """
   if type(v) is not core.Literal:
     return env[v]
   return [xb.constant(c, xla.canonicalize_dtype(v.val))]
示例#11
0
def _xla_consts(c, consts):
  """Build XLA constants for `consts`, one per distinct object.

  Returns a list parallel to `consts`. Each entry is a one-element list holding
  the XLA constant, and entries for the same object (by identity) share the
  very same list. `xb.constant` is called once per distinct object, in order of
  first appearance.
  """
  by_identity = {}
  for const in consts:
    key = id(const)
    if key not in by_identity:
      by_identity[key] = [xb.constant(c, const)]
  return [by_identity[id(const)] for const in consts]