Example 1
def _threefry2x32_gpu_translation_rule(ctx, avals_in, avals_out, k1, k2, x1,
                                       x2):
    aval_out, _ = avals_out
    k1_aval, k2_aval, x1_aval, x2_aval = avals_in
    rank = len(aval_out.shape)
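    # Zero-sized output: skip the PRNG kernel and return zero-filled results.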
    if 0 in aval_out.shape:
        zeros = xla_client.ops.Broadcast(
            xla_client.ops.Constant(ctx.builder, np.array(0, np.uint32)),
            aval_out.shape)
        return [zeros, zeros]

    def _broadcast(x, aval):
        return xla_client.ops.BroadcastInDim(
            x, aval_out.shape, tuple(range(rank - len(aval.shape), rank)))

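    # Dispatch to the CUDA PRNG kernel when it is available, otherwise to the ROCm/HIP one.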
    if cuda_prng:
        return xla.xla_destructure(
            ctx.builder,
            cuda_prng.threefry2x32(
                ctx.builder,
                (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
                (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))
    else:
        return xla.xla_destructure(
            ctx.builder,
            hip_prng.threefry2x32(
                ctx.builder,
                (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
                (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))
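
The snippets on this page all return xla.xla_destructure(builder, op), which unpacks a tuple-shaped XLA op into a flat list of per-element ops. A minimal sketch of what such a helper is assumed to do, written against the same legacy xla_client builder API as the snippets (the import path, method names, and the function itself are assumptions for illustration, not the actual JAX implementation):

from jax.lib import xla_client as xc  # assumed legacy import path

def destructure_sketch(builder, tuple_op):
    # Emit one GetTupleElement op per element of the tuple shape, so callers
    # receive a flat Python list of XLA ops.
    num_elements = len(builder.get_shape(tuple_op).tuple_shapes())
    return [xc.ops.GetTupleElement(tuple_op, i) for i in range(num_elements)]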
Example 2
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xla.set_sharding instead of xla.with_sharding because inlined calls
    # shouldn't have shardings set directly on the inputs or outputs.
    arg = xla.parameter(subc, i, ctx.builder.GetShape(n))
    args.append(xla.set_sharding(subc, arg, sharding))

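  # Lower call_jaxpr into the subcomputation and attach shardings to its outputs.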
  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
Example 3
File: ann.py  Project: 0x0is1/jax
def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k,
                                  reduction_dimension, recall_target, is_max_k,
                                  reduction_input_size_override,
                                  aggregate_to_topk):
  c = ctx.builder
  op_shape = c.get_shape(operand)
  if not op_shape.is_array():
    raise ValueError('operand must be an array, but was {}'.format(op_shape))
  op_dims = op_shape.dimensions()
  op_type = op_shape.element_type()
  if reduction_dimension < 0:
    reduction_dimension = len(op_dims) + reduction_dimension
  comparator = _comparator_builder(operand, op_type, is_max_k)
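  # The initial value is the reduction identity: -inf / dtype min for max-k, +inf / dtype max for min-k.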
  if is_max_k:
    if dtypes.issubdtype(op_type, np.floating):
      init_literal = np.array(np.NINF, dtype=op_type)
    else:
      init_literal = np.iinfo(op_type).min
  else:
    if dtypes.issubdtype(op_type, np.floating):
      init_literal = np.array(np.Inf, dtype=op_type)
    else:
      init_literal = np.iinfo(op_type).max
  iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
                     reduction_dimension)
  init_val = xc.ops.Constant(c, init_literal)
  init_arg = xc.ops.Constant(c, np.int32(-1))
  out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k,
                          reduction_dimension, comparator, recall_target,
                          aggregate_to_topk, reduction_input_size_override)
  return xla.xla_destructure(c, out)
Example 4
def _nonzero_translation_rule(c, dims, avals, operands):
  (vals,), = operands
  shape = c.get_shape(vals)
  last_axis = len(shape.dimensions()) - 1
  zeros = xops.Broadcast(xb.constant(c, np.zeros((), shape.numpy_dtype())),
                         shape.dimensions())
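  # 1 where vals is nonzero, 0 elsewhere, converted to int32.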
  s32_etype = xc.dtype_to_etype(np.dtype('int32'))
  nonzero_indicators = xops.ConvertElementType(xops.Ne(vals, zeros), s32_etype)
  i = core.ShapedArray((), np.dtype('int32'))
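  # Sum the indicators along the last axis to count the nonzero entries.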
  out_dim = xops.Reduce(c, [nonzero_indicators],
                        [xb.constant(c, np.array(0, np.dtype('int32')))],
                        xla.primitive_subcomputation(lax.add_p, i, i),
                        (last_axis,))
  c.get_shape(out_dim)  # xla type checking

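  # Stable descending sort of the indicators, co-sorting an iota of positions so the indices of nonzero entries come first.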
  subc = xb.make_computation_builder("sort_gt_comparator")
  params = [xb.parameter(subc, i, xc.Shape.array_shape(s32_etype, ()))
            for i in range(4)]
  comparator = subc.build(xops.Gt(params[0], params[1]))
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  ans = xops.Sort(c, [nonzero_indicators, xops.Iota(c, iota_shape, last_axis)],
                  is_stable=True, comparator=comparator)
  _, out_val = xla.xla_destructure(c, ans)
  c.get_shape(out_val)  # xla type checking
  return [[out_dim], [out_val]]
Example 5
def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
                           axis_env, platform):
  assert axis_name
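  # Contract locally on each device, then psum the partial result over the named axis.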
  local_out = lax._dot_general_translation_rule(
      c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=None)
  out_tup = xla.parallel_translations[psum_p](
      c, local_out, axis_name=axis_name, axis_index_groups=None,
      axis_env=axis_env, platform=platform)
  out, = xla.xla_destructure(c, out_tup)
  return out
Example 6
def _reduce_window_translation_rule(ctx, avals_in, avals_out, *args, jaxpr,
                                    consts, window_dimensions, window_strides,
                                    padding, base_dilation, window_dilation):
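    # The first half of args are the operands, the second half their corresponding init values.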
    operands, init_values = util.split_list(args, [len(args) // 2])
    xla_computation = lax._reduction_computation(ctx,
                                                 jaxpr,
                                                 consts,
                                                 init_values,
                                                 singleton=False)
    return xla.xla_destructure(
        ctx.builder,
        xops.ReduceWindowWithGeneralPadding(operands, init_values,
                                            xla_computation, window_dimensions,
                                            window_strides, base_dilation,
                                            window_dilation, padding))
Example 7
def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
                                       reduction_dimension, recall_target,
                                       is_max_k, reduction_input_size_override,
                                       aggregate_to_topk):
    c = ctx.builder
    op_shape = c.get_shape(operand)
    if not op_shape.is_array():
        raise ValueError(
            'operand must be an array, but was {}'.format(op_shape))
    op_dims = op_shape.dimensions()
    op_type = op_shape.element_type()

    if reduction_dimension < 0:
        reduction_dimension = len(op_dims) + reduction_dimension
    comparator = _comparator_builder(op_type, is_max_k)
    iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
                       reduction_dimension)
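    # ApproxTopKFallback needs a newer xla_client; on older versions emulate it with a full sort followed by a slice of the leading k entries.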
    if xc._version >= 60:
        init_val_literal = _get_init_val_literal(op_type, is_max_k)
        init_val = xc.ops.Constant(c, init_val_literal)
        init_arg = xc.ops.Constant(c, np.int32(-1))
        out = xc.ops.ApproxTopKFallback(c, [operand, iota],
                                        [init_val, init_arg], k,
                                        reduction_dimension, comparator,
                                        recall_target, aggregate_to_topk,
                                        reduction_input_size_override)
        return xla.xla_destructure(c, out)
    else:
        val_arg = xc.ops.Sort(c, [operand, iota], comparator,
                              reduction_dimension)
        vals = xc.ops.GetTupleElement(val_arg, 0)
        args = xc.ops.GetTupleElement(val_arg, 1)
        sliced_vals = xc.ops.SliceInDim(
            vals, 0, avals_out[0].shape[reduction_dimension], 1,
            reduction_dimension)
        sliced_args = xc.ops.SliceInDim(
            args, 0, avals_out[0].shape[reduction_dimension], 1,
            reduction_dimension)
        return sliced_vals, sliced_args
Example 8
def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
                                       reduction_dimension, recall_target,
                                       is_max_k, reduction_input_size_override,
                                       aggregate_to_topk):
    c = ctx.builder
    op_shape = c.get_shape(operand)
    if not op_shape.is_array():
        raise ValueError(f'operand must be an array, but was {op_shape}')
    op_dims = op_shape.dimensions()
    op_type = op_shape.element_type()

    if reduction_dimension < 0:
        reduction_dimension = len(op_dims) + reduction_dimension
    comparator = _comparator_builder(op_type, is_max_k)
    iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
                       reduction_dimension)
    init_val_literal = _get_init_val_literal(op_type, is_max_k)
    init_val = xc.ops.Constant(c, init_val_literal)
    init_arg = xc.ops.Constant(c, np.int32(-1))
    out = xc.ops.ApproxTopKFallback(c, [operand, iota], [init_val, init_arg],
                                    k, reduction_dimension, comparator,
                                    recall_target, aggregate_to_topk,
                                    reduction_input_size_override)
    return xla.xla_destructure(c, out)
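
Examples 7 and 8 call a helper _get_init_val_literal(op_type, is_max_k) that is not shown on this page. A minimal sketch consistent with the inline initial-value logic of Example 3 (an assumption for illustration, not necessarily the actual helper; np.issubdtype stands in for JAX's dtypes.issubdtype):

import numpy as np

def _get_init_val_literal_sketch(op_type, is_max_k):
    # Identity element for the selection: the smallest representable value when
    # keeping maxima, the largest when keeping minima.
    if np.issubdtype(op_type, np.floating):
        return np.array(-np.inf if is_max_k else np.inf, dtype=op_type)
    return np.array(np.iinfo(op_type).min if is_max_k else np.iinfo(op_type).max,
                    dtype=op_type)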