def _threefry2x32_gpu_translation_rule(ctx, avals_in, avals_out, k1, k2, x1,
                                       x2):
  """Lowers threefry2x32 to the vendor PRNG kernel (CUDA or ROCm).

  Broadcasts the two key and two data operands to the output shape, then
  invokes the backend's threefry2x32 custom call and destructures its tuple
  result into the two u32 output arrays.
  """
  aval_out, _ = avals_out
  k1_aval, k2_aval, x1_aval, x2_aval = avals_in
  rank = len(aval_out.shape)
  # A zero-sized output means there is nothing to hash: emit two broadcast
  # u32 zeros instead of calling the kernel.
  if 0 in aval_out.shape:
    zeros = xla_client.ops.Broadcast(
        xla_client.ops.Constant(ctx.builder, np.array(0, np.uint32)),
        aval_out.shape)
    return [zeros, zeros]

  def _broadcast(x, aval):
    # Align trailing dimensions of `aval` with the output shape.
    return xla_client.ops.BroadcastInDim(
        x, aval_out.shape, tuple(range(rank - len(aval.shape), rank)))

  # Both vendor modules expose the same threefry2x32 entry point; prefer the
  # CUDA build when it is available, otherwise fall back to ROCm.
  prng = cuda_prng if cuda_prng else hip_prng
  return xla.xla_destructure(
      ctx.builder,
      prng.threefry2x32(
          ctx.builder,
          (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
          (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts, name,
                                  call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  """Lowers a sharded_jit call: builds a subcomputation from `call_jaxpr`
  with sharding annotations on its parameters and outputs, then emits a
  Call into the enclosing builder and destructures the tuple result."""
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # Any extra leading in_nodes are assumed to be constants; pad in_parts with
  # None so they are replicated.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  # We use xla.set_sharding instead of xla.with_sharding because inlined calls
  # shouldn't have shardings set directly on the inputs or outputs.
  args = []
  for idx, (node, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    param = xla.parameter(subc, idx, ctx.builder.GetShape(node))
    args.append(xla.set_sharding(subc, param, sharding))

  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)

  # Annotate each output with its partitioning before tupling them up.
  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k,
                                  reduction_dimension, recall_target, is_max_k,
                                  reduction_input_size_override,
                                  aggregate_to_topk):
  """Lowers approx_top_k via XLA's ApproxTopK op (TPU path).

  Pairs the operand with an int32 iota of positions so the op returns both
  the top values and their indices; returns the destructured [values, args].
  """
  c = ctx.builder
  op_shape = c.get_shape(operand)
  if not op_shape.is_array():
    raise ValueError('operand must be an array, but was {}'.format(op_shape))
  op_dims = op_shape.dimensions()
  op_type = op_shape.element_type()
  # Normalize a negative reduction dimension to a non-negative index.
  if reduction_dimension < 0:
    reduction_dimension = len(op_dims) + reduction_dimension
  # NOTE(review): changed from _comparator_builder(operand, op_type, is_max_k)
  # to the two-argument form used by the fallback translations in this file —
  # confirm against _comparator_builder's signature.
  comparator = _comparator_builder(op_type, is_max_k)
  # The init value is the identity of the selection: the worst possible
  # score, so any real element beats it.
  if is_max_k:
    if dtypes.issubdtype(op_type, np.floating):
      init_literal = np.array(-np.inf, dtype=op_type)
    else:
      # BUG FIX: iinfo.min/.max are attributes, not methods; the original
      # called them, which raised TypeError for integer operands.
      init_literal = np.iinfo(op_type).min
  else:
    if dtypes.issubdtype(op_type, np.floating):
      init_literal = np.array(np.inf, dtype=op_type)
    else:
      init_literal = np.iinfo(op_type).max
  iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
                     reduction_dimension)
  init_val = xc.ops.Constant(c, init_literal)
  # -1 marks "no index yet" in the argument (index) operand.
  init_arg = xc.ops.Constant(c, np.int32(-1))
  out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k,
                          reduction_dimension, comparator, recall_target,
                          aggregate_to_topk, reduction_input_size_override)
  return xla.xla_destructure(c, out)
def _nonzero_translation_rule(c, dims, avals, operands):
  # Lowers a nonzero-style primitive: returns the per-last-axis count of
  # nonzero elements and the (sorted-to-front) positions of those elements.
  (vals,), = operands
  shape = c.get_shape(vals)
  last_axis = len(shape.dimensions()) - 1
  # Build a zeros array of the operand's shape/dtype, then form a 0/1 int32
  # indicator per element (1 where vals != 0).
  zeros = xops.Broadcast(xb.constant(c, np.zeros((), shape.numpy_dtype())),
                         shape.dimensions())
  s32_etype = xc.dtype_to_etype(np.dtype('int32'))
  nonzero_indicators = xops.ConvertElementType(xops.Ne(vals, zeros), s32_etype)
  # Sum the indicators along the last axis to get the nonzero count.
  # `i` here is the scalar int32 aval fed to the add subcomputation.
  i = core.ShapedArray((), np.dtype('int32'))
  out_dim = xops.Reduce(c, [nonzero_indicators],
                        [xb.constant(c, np.array(0, np.dtype('int32')))],
                        xla.primitive_subcomputation(lax.add_p, i, i),
                        (last_axis,))
  c.get_shape(out_dim)  # xla type checking
  # Stable sort by descending indicator, carrying an iota of positions along:
  # positions of nonzero elements come first, in their original order.
  # A sort comparator takes one parameter PAIR per operand — two operands
  # (indicator, iota) give 4 params — but only the indicator pair is compared.
  # Note the comprehension's `i` shadows the ShapedArray above only locally.
  subc = xb.make_computation_builder("sort_gt_comparator")
  params = [xb.parameter(subc, i, xc.Shape.array_shape(s32_etype, ()))
            for i in range(4)]
  comparator = subc.build(xops.Gt(params[0], params[1]))
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  ans = xops.Sort(c, [nonzero_indicators, xops.Iota(c, iota_shape, last_axis)],
                  is_stable=True, comparator=comparator)
  _, out_val = xla.xla_destructure(c, ans)
  c.get_shape(out_val)  # xla type checking
  return [[out_dim], [out_val]]
def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
                           axis_env, platform):
  """Lowers pdot: a local dot_general followed by a psum over `axis_name`."""
  assert axis_name
  # Contract the locally-held shards first ...
  partial_result = lax._dot_general_translation_rule(
      c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=None)
  # ... then all-reduce the partial products across the named axis.
  psum_rule = xla.parallel_translations[psum_p]
  reduced_tuple = psum_rule(c, partial_result, axis_name=axis_name,
                            axis_index_groups=None, axis_env=axis_env,
                            platform=platform)
  result, = xla.xla_destructure(c, reduced_tuple)
  return result
def _reduce_window_translation_rule(ctx, avals_in, avals_out, *args, jaxpr,
                                    consts, window_dimensions, window_strides,
                                    padding, base_dilation, window_dilation):
  """Lowers reduce_window: splits `args` into operands and init values,
  builds the reduction subcomputation from `jaxpr`, and emits XLA's
  ReduceWindowWithGeneralPadding, destructuring its tuple result."""
  # The first half of args are the operands, the second half their inits.
  half = len(args) // 2
  operands, init_values = util.split_list(args, [half])
  reducer = lax._reduction_computation(ctx, jaxpr, consts, init_values,
                                       singleton=False)
  windowed = xops.ReduceWindowWithGeneralPadding(
      operands, init_values, reducer, window_dimensions, window_strides,
      base_dilation, window_dilation, padding)
  return xla.xla_destructure(ctx.builder, windowed)
def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
                                       reduction_dimension, recall_target,
                                       is_max_k, reduction_input_size_override,
                                       aggregate_to_topk):
  # Fallback lowering of approx_top_k for backends without the native op.
  # NOTE(review): a second function with this same name appears later in the
  # file and will shadow this one at import time — confirm which is intended.
  c = ctx.builder
  op_shape = c.get_shape(operand)
  if not op_shape.is_array():
    raise ValueError(
        'operand must be an array, but was {}'.format(op_shape))
  op_dims = op_shape.dimensions()
  op_type = op_shape.element_type()
  # Normalize a negative reduction dimension to a non-negative index.
  if reduction_dimension < 0:
    reduction_dimension = len(op_dims) + reduction_dimension
  comparator = _comparator_builder(op_type, is_max_k)
  # int32 iota of positions, paired with the operand so indices ride along.
  iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
                     reduction_dimension)
  if xc._version >= 60:
    # Newer xla_client exposes a dedicated ApproxTopKFallback op.
    init_val_literal = _get_init_val_literal(op_type, is_max_k)
    init_val = xc.ops.Constant(c, init_val_literal)
    # -1 marks "no index yet" in the argument (index) operand.
    init_arg = xc.ops.Constant(c, np.int32(-1))
    out = xc.ops.ApproxTopKFallback(c, [operand, iota], [init_val, init_arg],
                                    k, reduction_dimension, comparator,
                                    recall_target, aggregate_to_topk,
                                    reduction_input_size_override)
    return xla.xla_destructure(c, out)
  else:
    # Older xla_client: emulate with a full sort of (values, indices) and a
    # slice of the leading k entries along the reduction dimension. The slice
    # length is taken from avals_out, not `k` directly.
    val_arg = xc.ops.Sort(c, [operand, iota], comparator, reduction_dimension)
    vals = xc.ops.GetTupleElement(val_arg, 0)
    args = xc.ops.GetTupleElement(val_arg, 1)
    sliced_vals = xc.ops.SliceInDim(
        vals, 0, avals_out[0].shape[reduction_dimension], 1,
        reduction_dimension)
    sliced_args = xc.ops.SliceInDim(
        args, 0, avals_out[0].shape[reduction_dimension], 1,
        reduction_dimension)
    return sliced_vals, sliced_args
def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
                                       reduction_dimension, recall_target,
                                       is_max_k,
                                       reduction_input_size_override,
                                       aggregate_to_topk):
  """Fallback lowering of approx_top_k via xla_client's ApproxTopKFallback.

  Pairs the operand with an int32 iota of positions so the op returns both
  the selected values and their indices; returns the destructured result.
  """
  builder = ctx.builder
  op_shape = builder.get_shape(operand)
  if not op_shape.is_array():
    raise ValueError(f'operand must be an array, but was {op_shape}')
  operand_dims = op_shape.dimensions()
  operand_type = op_shape.element_type()
  # Normalize a negative reduction dimension to a non-negative index.
  if reduction_dimension < 0:
    reduction_dimension += len(operand_dims)
  comparator = _comparator_builder(operand_type, is_max_k)
  iota = xc.ops.Iota(builder,
                     xc.Shape.array_shape(np.dtype(np.int32), operand_dims),
                     reduction_dimension)
  # Init values: the worst possible score, and -1 for "no index yet".
  init_val = xc.ops.Constant(builder,
                             _get_init_val_literal(operand_type, is_max_k))
  init_arg = xc.ops.Constant(builder, np.int32(-1))
  out = xc.ops.ApproxTopKFallback(
      builder, [operand, iota], [init_val, init_arg], k, reduction_dimension,
      comparator, recall_target, aggregate_to_topk,
      reduction_input_size_override)
  return xla.xla_destructure(builder, out)