Example #1
0
def _select_and_gather_add_transpose(t, tangents, operand, *, select_prim,
                                     window_dimensions, window_strides,
                                     padding, base_dilation, window_dilation):
  """Transpose rule of ``select_and_gather_add`` with respect to ``tangents``.

  The cotangent ``t`` is scattered back onto the positions picked by
  ``select_prim`` (``lax.le_p`` or ``lax.ge_p``) over ``operand`` using
  ``_select_and_scatter_add``. ``operand`` must be a known value, so its
  cotangent slot is always ``None``.

  Args:
    t: output cotangent; an ``ad_util.Zero`` short-circuits the computation.
    tangents: the linear input (must satisfy ``ad.is_undefined_primal``).
    operand: the known input that the selection is computed over.
    select_prim, window_dimensions, window_strides, padding, base_dilation,
      window_dilation: parameters of the primal ``select_and_gather_add``.

  Returns:
    ``[cotangent_for_tangents, None]``.

  Raises:
    NotImplementedError: if any entry of ``window_dilation`` is not 1.
  """
  assert select_prim in (lax.le_p, lax.ge_p)
  assert ad.is_undefined_primal(tangents)
  assert not ad.is_undefined_primal(operand)

  if not all(d == 1 for d in window_dilation):
    raise NotImplementedError(
        "VJP not implemented for select_and_gather (MaxPool) with window "
        "dilation, got window_dilation={}.".format(window_dilation))

  # A symbolic-zero cotangent transposes to a symbolic zero.
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(tangents.aval), None]

  dilated = not all(d == 1 for d in base_dilation)
  if dilated:
    # Insert (d - 1) identity elements between neighbours along each axis so
    # the selection runs over the base-dilated operand. The max identity is
    # used for ge_p and the min identity otherwise, presumably so that the
    # inserted slots are not preferred over real data.
    if select_prim is lax.ge_p:
      identity_fn = lax._get_max_identity
    else:
      identity_fn = lax._get_min_identity
    interior_cfg = tuple((0, 0, d - 1) for d in base_dilation)
    operand = lax.pad(operand, identity_fn(operand.dtype), interior_cfg)

  scattered = _select_and_scatter_add(t, operand, select_prim,
                                      window_dimensions, window_strides,
                                      padding)
  if not dilated:
    return [scattered, None]

  # Keep every base_dilation-th element, i.e. the pre-dilation positions.
  start = (0,) * len(scattered.shape)
  undilated = slicing.slice(scattered, start, scattered.shape, base_dilation)
  return [undilated, None]
Example #2
0
def _select_and_scatter_add_transpose(
    t, source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  """Transpose rule of ``select_and_scatter_add`` with respect to ``source``.

  ``source`` must be the linear input (``ad.is_undefined_primal``) and
  ``operand`` a known value. A symbolic-zero ``t`` yields a symbolic zero
  shaped like ``source``. Otherwise ``t`` is gathered back through
  ``_select_and_gather_add``, with an all-ones tuple passed for both trailing
  arguments (presumably base and window dilation).

  Returns:
    ``[cotangent_for_source, None]``; ``operand`` receives no cotangent.
  """
  assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand)
  if type(t) is not ad_util.Zero:
    unit_dilation = tuple(1 for _ in window_dimensions)
    gathered = _select_and_gather_add(
        t, operand, select_prim, window_dimensions, window_strides, padding,
        unit_dilation, unit_dilation)
    return [gathered, None]
  return [ad_util.Zero(source.aval), None]
Example #3
0
 def transposed(*args):
   """Transpose the closed-over ``jaxpr`` for one flat argument list.

   ``args`` is unflattened with the enclosing ``treedef`` into
   ``(in_primals, out_cts)``. Primals for which ``ad.is_undefined_primal``
   holds are the linear inputs. The remaining known primals are used to
   partially evaluate ``jaxpr`` into a linear jaxpr, which
   ``ad.backward_pass`` then runs against ``out_cts``.

   Returns the flattened input cotangents, with ``ad_util.Zero`` for every
   known primal. As a side effect, the tree structure of those cotangents is
   stored on ``cell.treedef``.

   NOTE(review): ``treedef``, ``jaxpr``, ``reduce_axes`` and ``cell`` are free
   variables from the enclosing scope, which is not shown here.
   """
   primals, cts_out = tree_unflatten(treedef, args)
   is_linear = [ad.is_undefined_primal(p) for p in primals]

   # Linear inputs stay unknown; known primals are folded in as constants.
   pvals = [pe.PartialVal.unknown(p.aval) if lin else pe.PartialVal.known(p)
            for p, lin in zip(primals, is_linear)]
   eval_fn = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
   lin_jaxpr, _, lin_consts = pe.trace_to_jaxpr_nounits(eval_fn, pvals, False)

   placeholders = [ad.UndefinedPrimal(var.aval) for var in lin_jaxpr.invars]
   backward_cts = ad.backward_pass(lin_jaxpr, reduce_axes, False, lin_consts,
                                   placeholders, cts_out)

   # Re-expand to one cotangent per primal, in the original order; every
   # cotangent produced by the backward pass must be consumed exactly once.
   remaining = iter(backward_cts)
   cts_in = [next(remaining) if lin else ad_util.Zero(p.aval)
             for p, lin in zip(primals, is_linear)]
   assert next(remaining, None) is None

   flat_cts_in, cell.treedef = tree_flatten(cts_in)
   return flat_cts_in