コード例 #1
0
ファイル: common.py プロジェクト: xueeinstein/jax
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable], in_tree,
                                             in_avals, primitive_name: str):
    """Stage every function in ``funs`` into jaxprs that share one const list.

    Each function is traced with ``_initial_style_open_jaxpr``. The constants
    a traced function closes over show up as that jaxpr's ``constvars``.

    A conditional *primitive* needs all of its branches to have the same
    input signature. So each jaxpr is padded with fresh, never-read variables
    that stand in for the constants of every other function. After padding,
    jaxpr ``i`` has these constvars, in order:
      * the placeholders for functions ``0..i-1``;
      * its own constvars;
      * the placeholders for functions ``i+1..``.
    Those constvars are then turned into jaxpr inputs with
    ``pe.convert_constvars_jaxpr``.

    Returns:
      A triple ``(closed_jaxprs, consts, all_out_trees)``:
        * ``closed_jaxprs``: one ``core.ClosedJaxpr`` per function, each
          built with an empty consts tuple;
        * ``consts``: the constants of all functions concatenated in
          ``funs`` order;
        * ``all_out_trees``: the output trees reported by each trace.
    """
    staged = [
        _initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
        for fun in funs
    ]
    open_jaxprs, consts_per_fun, all_out_trees = unzip3(staged)

    newvar = core.gensym(open_jaxprs, suffix='_')
    avals_per_fun = [map(_abstractify, consts) for consts in consts_per_fun]
    # One fresh variable per constant of every function. A given jaxpr only
    # uses these as placeholders for the *other* functions' constants.
    placeholders = [map(newvar, avals) for avals in avals_per_fun]

    consts = util.concatenate(consts_per_fun)

    padded = []
    for i, jaxpr in enumerate(open_jaxprs):
        before = util.concatenate(placeholders[:i])
        after = util.concatenate(placeholders[i + 1:])
        padded.append(
            jaxpr.replace(constvars=[*before, *jaxpr.constvars, *after]))

    closed_jaxprs = [
        core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
        for jaxpr in padded
    ]
    return closed_jaxprs, consts, all_out_trees
コード例 #2
0
ファイル: batching.py プロジェクト: wayfeng/jax
    def post_process_map(self, call_primitive, out_tracers, params):
        """Unbox the outputs of a map call and describe how to re-box them.

        Returns a pair ``(vals, todo)``. ``vals`` holds the raw values of
        ``out_tracers``.

        ``todo`` takes new values and wraps each one in a ``BatchTracer``
        built on ``main.with_cur_sublevel()``. The batch dim is bumped by one
        when three things hold: the output axis from
        ``params['out_axes_thunk']()`` is not ``None``, the value is batched,
        and that output axis is ``<=`` the batch dim.

        For map primitives (``call_primitive.map_primitive``), ``todo`` is
        instead the pair ``(todo, out_axes_transform)``. The transform adds
        one to each output axis that is mapped, whose value is batched, and
        whose batch dim is strictly smaller than the output axis.
        """
        vals, dims, srcs = unzip3(
            (tracer.val, tracer.batch_dim, tracer.source_info)
            for tracer in out_tracers)
        main = self.main

        def both_mapped(axis, bdim):
            # The map output axis exists AND the value carries a batch dim.
            return axis is not None and bdim is not not_mapped

        def shifted_bdim(bdim, axis):
            if both_mapped(axis, bdim) and axis <= bdim:
                return bdim + 1
            return bdim

        def todo(vals):
            trace = main.with_cur_sublevel()
            out_axes = params['out_axes_thunk']()
            return [BatchTracer(trace, val, shifted_bdim(bdim, axis), src)
                    for val, bdim, axis, src in zip(vals, dims, out_axes, srcs)]

        if not call_primitive.map_primitive:
            return vals, todo

        def out_axes_transform(out_axes):
            return tuple(
                axis + 1 if both_mapped(axis, bdim) and bdim < axis else axis
                for axis, bdim in zip(out_axes, dims))

        return vals, (todo, out_axes_transform)
コード例 #3
0
ファイル: batching.py プロジェクト: xueeinstein/jax
 def post_process_custom_vjp_call(self, out_tracers, _):
   """Unbox batched custom_vjp outputs and return a callback to re-box them.

   Returns ``(vals, todo)``. ``vals`` are the raw values of ``out_tracers``.
   ``todo(new_vals)`` wraps each new value in a ``BatchTracer`` built on
   ``main.with_cur_sublevel()``. Each tracer reuses the batch dim and source
   info of the corresponding original output.
   """
   triples = [(t.val, t.batch_dim, t.source_info) for t in out_tracers]
   vals, dims, srcs = unzip3(triples)
   main = self.main

   def todo(new_vals):
     trace = main.with_cur_sublevel()
     make_tracer = partial(BatchTracer, trace)
     return map(make_tracer, new_vals, dims, srcs)

   return vals, todo
コード例 #4
0
def _map_coordinates(input, coordinates, order, mode, cval):
    """Sample ``input`` at ``coordinates`` using order-0 or order-1 interpolation.

    Args:
      input: array-like to sample from; converted with ``jnp.asarray``.
      coordinates: one coordinate array per dimension of ``input``.
      order: interpolation order; only 0 and 1 are handled.
      mode: boundary mode; must be a key of ``_INDEX_FIXERS``. In
        ``'constant'`` mode, an interpolation tap whose per-axis index falls
        outside ``[0, size)`` on any axis contributes ``cval`` instead of a
        gathered value.
      cval: fill value, cast to ``input``'s dtype.

    Returns:
      The interpolated values cast to ``input``'s dtype. Integer inputs are
      rounded with ``_round_half_away_from_zero`` before the cast.

    Raises:
      ValueError: if the number of coordinate arrays differs from
        ``input.ndim``.
      NotImplementedError: for a ``mode`` missing from ``_INDEX_FIXERS`` or
        an ``order`` other than 0 or 1.
    """
    input_arr = jnp.asarray(input)
    coord_arrs = [jnp.asarray(c) for c in coordinates]
    cval = jnp.asarray(cval, input_arr.dtype)

    if len(coord_arrs) != input_arr.ndim:
        raise ValueError(
            'coordinates must be a sequence of length input.ndim, but '
            '{} != {}'.format(len(coord_arrs), input_arr.ndim))

    index_fixer = _INDEX_FIXERS.get(mode)
    if index_fixer is None:
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
            'Currently supported modes are {}.'.format(mode,
                                                       set(_INDEX_FIXERS)))

    bounds_checked = mode == 'constant'

    def is_valid(index, size):
        # Outside 'constant' mode, validity is the plain Python ``True``.
        # That lets the gather below take its unmasked fast path.
        if bounds_checked:
            return (0 <= index) & (index < size)
        return True

    if order == 0:
        interp_fun = _nearest_indices_and_weights
    elif order == 1:
        interp_fun = _linear_indices_and_weights
    else:
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates currently requires order<=1')

    # For every axis, the list of (fixed index, validity, weight) taps.
    per_axis_taps = [
        [(index_fixer(index, size), is_valid(index, size), weight)
         for index, weight in interp_fun(coordinate)]
        for coordinate, size in zip(coord_arrs, input_arr.shape)
    ]

    def weighted_sample(taps):
        # One corner of the interpolation stencil: its weight times its value.
        indices, validities, weights = util.unzip3(taps)
        if all(valid is True for valid in validities):
            # fast path
            sample = input_arr[indices]
        else:
            all_valid = functools.reduce(operator.and_, validities)
            sample = jnp.where(all_valid, input_arr[indices], cval)
        return _nonempty_prod(weights) * sample

    result = _nonempty_sum(
        [weighted_sample(taps) for taps in itertools.product(*per_axis_taps)])
    if jnp.issubdtype(input_arr.dtype, jnp.integer):
        result = _round_half_away_from_zero(result)
    return result.astype(input_arr.dtype)
コード例 #5
0
ファイル: batching.py プロジェクト: xueeinstein/jax
 def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
   """Unbox batched custom_jvp outputs and return a callback to re-box them.

   Returns ``(vals, todo)``. ``vals`` are the raw values of ``out_tracers``.
   ``todo(new_vals)`` wraps each new value in a ``BatchTracer`` built on
   ``main.with_cur_sublevel()``.

   When ``jvp_was_run`` is true, the outputs are primals followed by
   tangents. ``todo`` asserts that primal and tangent batch dims match, then
   uses only the first ``len(new_vals)`` dims and source infos.
   """
   vals, dims, srcs = unzip3(
       (t.val, t.batch_dim, t.source_info) for t in out_tracers)
   main = self.main

   def todo(new_vals):
     trace = main.with_cur_sublevel()
     if not jvp_was_run:
       return map(partial(BatchTracer, trace), new_vals, dims, srcs)
     n = len(new_vals)
     primal_dims, tangent_dims = dims[:n], dims[n:]
     assert primal_dims == tangent_dims
     return map(partial(BatchTracer, trace), new_vals, primal_dims, srcs[:n])

   return vals, todo
コード例 #6
0
def pad(operand, padding_value, padding_config):
  """NumPy reference implementation of XLA's Pad operation.

  Args:
    operand: NumPy array to pad. Only its ``shape`` and ``dtype`` are read,
      and it is assigned into the output buffer.
    padding_value: value used for every inserted element.
    padding_config: one ``(low, high, interior)`` triple per dimension of
      ``operand``.
      * ``low`` / ``high``: edge padding. A negative value trims elements
        from that edge.
      * ``interior``: number of padding elements inserted between each pair
        of adjacent elements.

  Returns:
    A new array holding the padded (and, for negative edges, trimmed) result.
  """
  # https://www.tensorflow.org/xla/operation_semantics#pad
  lo, hi, interior = util.unzip3(padding_config)
  # Handle first the positive edge padding and interior
  lo_pos, hi_pos = np.clip(lo, 0, None), np.clip(hi, 0, None)
  # Interior padding fills the (n - 1) gaps between adjacent elements. A
  # zero-sized dimension has no gaps, so clamp at 0. Without the clamp it
  # would contribute ``-interior`` and give a wrong or negative output size.
  gaps = np.maximum(np.subtract(operand.shape, 1), 0)
  outshape = np.add(np.add(np.add(lo_pos, hi_pos), operand.shape),
                    np.multiply(interior, gaps))
  out = np.full(outshape, padding_value, operand.dtype)
  # Scatter the operand with stride (interior + 1), after the positive low
  # padding and before the positive high padding.
  lhs_slices = tuple(_slice(l if l > 0 else 0, -h if h > 0 else None, step)
                     for l, h, step in zip(lo_pos, hi_pos, np.add(1, interior)))
  out[lhs_slices] = operand
  # Negative edge padding removes elements from the corresponding edge.
  trim_slices = tuple(_slice(-l if l < 0 else 0, h if h < 0 else None)
                      for l, h in zip(lo, hi))
  return out[trim_slices]
コード例 #7
0
ファイル: batching.py プロジェクト: xueeinstein/jax
 def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees):
   """Unbox the batched outputs of a custom_vjp forward pass.

   ``out_trees`` is a thunk whose second result is the residual tree. The
   first ``res_tree.num_leaves`` outputs are residuals; the rest are primal
   outputs.

   Returns ``(vals, todo, bwd_transform)``:
     * ``vals``: the raw output values.
     * ``todo``: re-wraps new primal values as ``BatchTracer``s, reusing the
       primal outputs' batch dims and source info.
     * ``bwd_transform``: batches the backward rule via
       ``batch_custom_vjp_bwd``, using the axis name, the shared mapped axis
       size, and the batch dims of all outputs.

   The unpacking of the mapped axis sizes raises ``ValueError`` unless the
   mapped outputs agree on exactly one size.
   """
   triples = ((t.val, t.batch_dim, t.source_info) for t in out_tracers)
   vals, dims, srcs = unzip3(triples)
   sizes = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
   axis_size, = sizes
   main = self.main
   trace_type = main.trace_type
   axis_name = self.axis_name
   _, res_tree = out_trees()
   num_res = res_tree.num_leaves
   # Only the primal (non-residual) outputs are re-wrapped by ``todo``.
   _, primal_dims = split_list(dims, [num_res])
   _, primal_srcs = split_list(srcs, [num_res])

   def todo(new_vals):
     trace = main.with_cur_sublevel()
     return map(partial(BatchTracer, trace), new_vals, primal_dims, primal_srcs)

   def bwd_transform(bwd):
     return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,),
                                 trace_type)

   return vals, todo, bwd_transform
コード例 #8
0
ファイル: impl_no_xla.py プロジェクト: matthewfeickert/jax
def _pad(operand, padding_value, *, padding_config,
         _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray):
    """Implement ``pad`` with TensorFlow ops, in three stages.

    1. Interior padding via ``_interior_padding``, only when some interior
       amount is non-zero.
    2. Non-negative edge padding via ``tf.pad`` in ``"CONSTANT"`` mode.
    3. Negative edge padding via ``tf.slice``, only when some edge amount is
       negative.

    Args:
      operand: the value to pad; it is passed to ``tf.pad`` / ``tf.slice``.
      padding_value: constant used for every inserted element.
      padding_config: one ``(low, high, interior)`` triple per dimension.
      _in_avals: abstract inputs. ``_in_avals[0].shape`` is evaluated for the
        interior-padding stage.
      _out_aval: abstract output. Its evaluated shape is the ``tf.slice``
        size for negative padding.

    Returns:
      The padded operand.
    """
    # Destructure once and reuse the per-dimension sequences below.
    low, high, interior = util.unzip3(padding_config)

    # Do only the interior padding first. This is rarely needed.
    if any(i != 0 for i in interior):
        operand = _interior_padding(operand, padding_value, padding_config,
                                    jax2tf._eval_shape(_in_avals[0].shape))

    # Now do the non-negative edge padding. This is the common case, use tf.pad.
    non_negative_padding = [((lo if lo >= 0 else 0), (hi if hi >= 0 else 0))
                            for lo, hi in zip(low, high)]
    operand = tf.pad(operand,
                     non_negative_padding,
                     mode="CONSTANT",
                     constant_values=padding_value)
    # Now the negative edge padding (this is also rare): start past the
    # trimmed low elements and keep exactly the expected output shape.
    if any(lo < 0 or hi < 0 for lo, hi in zip(low, high)):
        output_shape = jax2tf._eval_shape(_out_aval.shape)
        begins = [(-lo if lo < 0 else 0) for lo in low]
        operand = tf.slice(operand, begins, output_shape)

    return operand