def _initial_style_jaxprs_with_common_consts(
    funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
  # When staging the branches of a conditional into jaxprs, constants are
  # extracted from each branch and converted to jaxpr arguments. To use the
  # staged jaxprs as the branches to a conditional *primitive*, we need for
  # their (input) signatures to match. This function "joins" the staged jaxprs:
  # for each one, it makes another that accepts *all* constants, but only uses
  # those that it needs (dropping the rest).
  jaxprs, all_consts, all_out_trees = \
      unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
             for fun in funs)

  newvar = core.gensym(jaxprs, suffix='_')
  all_const_avals = [map(_abstractify, consts) for consts in all_consts]
  unused_const_vars = [map(newvar, const_avals)
                       for const_avals in all_const_avals]

  def pad_jaxpr_constvars(i, jaxpr):
    prefix = util.concatenate(unused_const_vars[:i])
    suffix = util.concatenate(unused_const_vars[i + 1:])
    constvars = [*prefix, *jaxpr.constvars, *suffix]
    return jaxpr.replace(constvars=constvars)

  consts = util.concatenate(all_consts)
  jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
  closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
                   for jaxpr in jaxprs]
  return closed_jaxprs, consts, all_out_trees
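# Hedged illustration (not part of the library code above): this constant
# joining is what lets `lax.cond` accept branches that close over *different*
# constants. Each staged branch jaxpr is padded to accept the union of all
# branches' constants while using only its own. A minimal sketch:
import jax.numpy as jnp
from jax import lax

c_true = jnp.ones(3)        # closed over by the true branch only
c_false = jnp.arange(3.0)   # closed over by the false branch only

def true_fn(x):
  return x + c_true

def false_fn(x):
  return x * c_false

# Both branches are staged with a common signature: each padded jaxpr
# receives [c_true, c_false] as constants but uses only the ones it needs.
out = lax.cond(True, true_fn, false_fn, jnp.zeros(3))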
def post_process_map(self, call_primitive, out_tracers, params):
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)
  main = self.main
  def both_mapped(in_out_axis, d):
    return in_out_axis is not None and d is not not_mapped
  def todo(vals):
    trace = main.with_cur_sublevel()
    return [BatchTracer(trace, v, d + 1 if both_mapped(oa, d) and oa <= d else d, s)
            for v, d, oa, s in zip(vals, dims, params['out_axes_thunk'](), srcs)]
  if call_primitive.map_primitive:
    def out_axes_transform(out_axes):
      return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis
                   else out_axis
                   for out_axis, d in zip(out_axes, dims))
    todo = (todo, out_axes_transform)
  return vals, todo
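# Hedged sketch (assumes at least one local device): this post-processing
# path runs when batched tracers escape into a call/map primitive, e.g. a
# value closed over from an outer vmap inside pmap. When both the map's
# output axis and the batch dim are present, the batch dim is shifted past
# the mapped axis, as in the `d + 1` adjustment above.
import jax
import jax.numpy as jnp

def outer(x):
  # `x` is a batched tracer from the outer vmap; the pmap body closes over
  # it, so the batching trace must post-process the map's outputs.
  return jax.pmap(lambda y: y + x)(jnp.zeros((jax.local_device_count(),)))

print(jax.vmap(outer)(jnp.arange(3.0)))  # shape (3, n_devices)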
def post_process_custom_vjp_call(self, out_tracers, _):
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)
  main = self.main
  def todo(vals):
    trace = main.with_cur_sublevel()
    return map(partial(BatchTracer, trace), vals, dims, srcs)
  return vals, todo
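# Hedged sketch: this simple passthrough covers custom_vjp calls traced
# under vmap when no differentiation is in play, so output batch dims are
# forwarded unchanged (the residual-splitting variant appears further
# below). A minimal custom_vjp function batched without grad:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def g(x):
  return x * 2.0

g.defvjp(lambda x: (x * 2.0, None), lambda _, ct: (2.0 * ct,))

# No grad/vjp here: outputs are just re-wrapped with the same batch dims.
print(jax.vmap(g)(jnp.arange(3.0)))  # [0. 2. 4.]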
def _map_coordinates(input, coordinates, order, mode, cval):
  input = jnp.asarray(input)
  coordinates = [jnp.asarray(c) for c in coordinates]
  cval = jnp.asarray(cval, input.dtype)

  if len(coordinates) != input.ndim:
    raise ValueError('coordinates must be a sequence of length input.ndim, but '
                     '{} != {}'.format(len(coordinates), input.ndim))

  index_fixer = _INDEX_FIXERS.get(mode)
  if index_fixer is None:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
        'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))

  if mode == 'constant':
    is_valid = lambda index, size: (0 <= index) & (index < size)
  else:
    is_valid = lambda index, size: True

  if order == 0:
    interp_fun = _nearest_indices_and_weights
  elif order == 1:
    interp_fun = _linear_indices_and_weights
  else:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates currently requires order<=1')

  valid_1d_interpolations = []
  for coordinate, size in zip(coordinates, input.shape):
    interp_nodes = interp_fun(coordinate)
    valid_interp = []
    for index, weight in interp_nodes:
      fixed_index = index_fixer(index, size)
      valid = is_valid(index, size)
      valid_interp.append((fixed_index, valid, weight))
    valid_1d_interpolations.append(valid_interp)

  outputs = []
  for items in itertools.product(*valid_1d_interpolations):
    indices, validities, weights = util.unzip3(items)
    if all(valid is True for valid in validities):
      # fast path
      contribution = input[indices]
    else:
      all_valid = functools.reduce(operator.and_, validities)
      contribution = jnp.where(all_valid, input[indices], cval)
    outputs.append(_nonempty_prod(weights) * contribution)
  result = _nonempty_sum(outputs)
  if jnp.issubdtype(input.dtype, jnp.integer):
    result = _round_half_away_from_zero(result)
  return result.astype(input.dtype)
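# Hedged usage sketch of the public wrapper, jax.scipy.ndimage.map_coordinates,
# which dispatches to _map_coordinates above:
import jax.numpy as jnp
from jax.scipy import ndimage

image = jnp.arange(12.0).reshape(3, 4)
# Sample at (row, col) = (0.5, 1.5) and (2.0, 3.0) with linear interpolation
# (order=1); out-of-bounds points would take `cval` under mode='constant'.
coords = [jnp.array([0.5, 2.0]), jnp.array([1.5, 3.0])]
print(ndimage.map_coordinates(image, coords, order=1))  # [ 3.5 11. ]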
def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)
  main = self.main
  def todo(vals):
    trace = main.with_cur_sublevel()
    if jvp_was_run:
      primal_dims, tangent_dims = dims[:len(vals)], dims[len(vals):]
      assert primal_dims == tangent_dims
      primal_srcs = srcs[:len(vals)]
      return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
    else:
      return map(partial(BatchTracer, trace), vals, dims, srcs)
  return vals, todo
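# Hedged usage sketch: the primal/tangent splitting above (taken when
# jvp_was_run is true) corresponds to a custom_jvp rule having already run
# under a vmap trace, where primal and tangent outputs must carry matching
# batch dims. For example:
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  return jnp.sin(x), jnp.cos(x) * t

print(jax.vmap(lambda x: jax.jvp(f, (x,), (jnp.ones_like(x),)))(jnp.arange(3.0)))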
def pad(operand, padding_value, padding_config):
  # https://www.tensorflow.org/xla/operation_semantics#pad
  lo, hi, interior = util.unzip3(padding_config)
  # Handle first the positive edge padding and interior
  lo_pos, hi_pos = np.clip(lo, 0, None), np.clip(hi, 0, None)
  outshape = np.add(np.add(np.add(lo_pos, hi_pos), operand.shape),
                    np.multiply(interior, np.subtract(operand.shape, 1)))
  out = np.full(outshape, padding_value, operand.dtype)
  lhs_slices = tuple(_slice(l if l > 0 else 0, -h if h > 0 else None, step)
                     for l, h, step in zip(lo_pos, hi_pos, np.add(1, interior)))
  out[lhs_slices] = operand
  trim_slices = tuple(_slice(-l if l < 0 else 0, h if h < 0 else None)
                      for l, h in zip(lo, hi))
  return out[trim_slices]
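# Hedged example of the (lo, hi, interior) padding_config semantics this
# NumPy reference implementation mirrors, shown via jax.lax.pad:
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
# lo=1, hi=2, interior=1: one leading pad, two trailing pads, and one pad
# between each pair of neighbors.
print(lax.pad(x, 0.0, [(1, 2, 1)]))    # [0. 1. 0. 2. 0. 3. 0. 0.]
# Negative edge padding trims instead (handled above by trim_slices):
print(lax.pad(x, 0.0, [(-1, 0, 0)]))   # [2. 3.]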
def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees):
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)
  axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
  main, trace_type = self.main, self.main.trace_type
  axis_name = self.axis_name
  _, res_tree = out_trees()
  num_res = res_tree.num_leaves
  res_dims, primal_dims = split_list(dims, [num_res])
  _, primal_srcs = split_list(srcs, [num_res])
  def todo(vals):
    trace = main.with_cur_sublevel()
    return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
  def bwd_transform(bwd):
    return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,),
                                trace_type)
  return vals, todo, bwd_transform
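# Hedged usage sketch: the fwd rule's outputs are [residuals, primals]; the
# code above splits off the residual batch dims and arranges for the bwd
# rule to be batched to match. A standard custom_vjp function whose
# residuals flow through batching under vmap + grad:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)

def f_fwd(x):
  return jnp.sin(x), jnp.cos(x)   # residual: cos(x)

def f_bwd(res, g):
  return (res * g,)

f.defvjp(f_fwd, f_bwd)
print(jax.vmap(jax.grad(f))(jnp.arange(3.0)))  # cos([0. 1. 2.])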
def _pad(operand, padding_value, *, padding_config,
         _in_avals: Sequence[core.ShapedArray],
         _out_aval: core.ShapedArray):
  low, high, interior = util.unzip3(padding_config)

  # Do only the interior padding first. This is rarely needed.
  if any(i != 0 for _, _, i in padding_config):
    operand = _interior_padding(operand, padding_value, padding_config,
                                jax2tf._eval_shape(_in_avals[0].shape))

  # Now do the non-negative edge padding. This is the common case, use tf.pad.
  non_negative_padding = [((lo if lo >= 0 else 0), (hi if hi >= 0 else 0))
                          for lo, hi, _ in padding_config]
  operand = tf.pad(operand, non_negative_padding,
                   mode="CONSTANT",
                   constant_values=padding_value)

  # Now the negative edge padding (this is also rare)
  if any(lo < 0 or hi < 0 for lo, hi, _ in padding_config):
    output_shape = jax2tf._eval_shape(_out_aval.shape)
    begins = [(-lo if lo < 0 else 0) for lo, _, _ in padding_config]
    operand = tf.slice(operand, begins, output_shape)
  return operand
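# Hedged usage sketch (assumes TensorFlow is installed): converting a
# lax.pad call with jax2tf exercises all three steps above when the
# padding_config mixes interior, positive, and negative edge padding.
import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf

f = lambda x: lax.pad(x, 0.0, [(1, -1, 1)])
f_tf = jax2tf.convert(f)
x = jnp.array([1.0, 2.0, 3.0])
print(f_tf(x))  # matches f(x): [0. 1. 0. 2. 0.]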